use std::sync::Arc;
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, Decimal128Array};
use datafusion::{execution::context::SessionContext, prelude::Expr};
use crate::supertable::{
error::QueryError,
handle::{Supertable, SupertableReader},
query::{
covered_agg::CoveredAggregateRewrite,
exec::{
fts_exec::register_bm25, hybrid_exec::register_hybrid_search,
match_exec::register_match, vector_exec::register_vector_search,
},
provider::{SupertableProvider, TABLE_NAME},
},
};
impl SupertableReader {
    /// Plans and executes `sql` against this reader and collects every result
    /// batch.
    ///
    /// Only compiled for tests or with the `test-helpers` feature.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::Plan`] if the session cannot be built or the
    /// statement fails to plan, and [`QueryError::Execute`] if execution fails.
    #[cfg(any(test, feature = "test-helpers"))]
    pub fn query_sql(&self, sql: &str) -> Result<Vec<RecordBatch>, QueryError> {
        let ctx = self.sql_session_context()?;
        // Own the SQL text so the `async move` future holds no borrow of `sql`.
        let sql = sql.to_owned();
        let drive = async move {
            let df = ctx
                .sql(&sql)
                .await
                .map_err(|e| QueryError::Plan(e.to_string()))?;
            df.collect()
                .await
                .map_err(|e| QueryError::Execute(e.to_string()))
        };
        // NOTE(review): `block_on` presumably drives the future to completion on
        // the reader's runtime — confirm against its definition.
        self.block_on(drive)
    }

    /// Returns a DataFusion [`SessionContext`] for this reader's manifest.
    ///
    /// The context is cached in `sql_session_cache`. The cache key is the
    /// manifest `Arc` pointer: when the cached manifest is pointer-equal to
    /// this reader's manifest, the cached context is cloned and returned.
    /// Otherwise a new context is built and stored in the cache. A new context:
    /// - registers the [`CoveredAggregateRewrite`] optimizer rule;
    /// - registers a [`SupertableProvider`] under [`TABLE_NAME`];
    /// - registers the vector-search, BM25, match and hybrid-search functions.
    ///   Each function gets a shared handle to a clone of this reader.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::Plan`] if registering the table provider fails.
    fn sql_session_context(&self) -> Result<SessionContext, QueryError> {
        let manifest = Arc::clone(self.manifest());
        // The cache entry is only ever replaced by one whole-value assignment.
        // A poisoned lock (a panic elsewhere while it was held) therefore still
        // guards a consistent entry. Recover rather than panicking on every
        // later query.
        let mut guard = self
            .sql_session_cache()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some((cached, ctx)) = &*guard
            && Arc::ptr_eq(cached, &manifest)
        {
            return Ok(ctx.clone());
        }
        // Only clone the reader on a cache miss. The functions registered below
        // each receive a handle to this clone.
        let reader = Arc::new(self.clone());
        let options = self.options();
        let scalar_schema = options.scalar_schema();
        let provider = SupertableProvider::new(
            Arc::clone(&scalar_schema),
            Arc::clone(&manifest),
            Arc::clone(&options.store),
            options.disk_cache.as_ref().map(Arc::clone),
            reader.tombstone_cache.clone(),
        );
        let ctx = SessionContext::new();
        ctx.add_optimizer_rule(Arc::new(CoveredAggregateRewrite));
        ctx.register_table(TABLE_NAME, Arc::new(provider))
            .map_err(|e| QueryError::Plan(e.to_string()))?;
        register_vector_search(&ctx, Arc::clone(&reader), Arc::clone(&scalar_schema));
        register_bm25(&ctx, Arc::clone(&reader), Arc::clone(&scalar_schema));
        register_match(&ctx, Arc::clone(&reader), Arc::clone(&scalar_schema));
        register_hybrid_search(&ctx, Arc::clone(&reader), Arc::clone(&scalar_schema));
        *guard = Some((manifest, ctx.clone()));
        Ok(ctx)
    }

    /// Returns the values of the configured ID column for every row of
    /// [`TABLE_NAME`] matching `expr`. Null IDs are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::Plan`] for session, filter or projection failures,
    /// or when the result shape is not a single `Decimal128` column.
    /// Returns [`QueryError::Execute`] if collecting the result fails.
    pub(crate) fn scan_ids_matching(&self, expr: Expr) -> Result<Vec<i128>, QueryError> {
        let ctx = self.sql_session_context()?;
        let id_column = self.options().id_column.clone();
        let drive = async move {
            let df = ctx
                .table(TABLE_NAME)
                .await
                .map_err(|e| QueryError::Plan(e.to_string()))?
                .filter(expr)
                .map_err(|e| QueryError::Plan(e.to_string()))?
                .select_columns(&[id_column.as_str()])
                .map_err(|e| QueryError::Plan(e.to_string()))?;
            let batches = df
                .collect()
                .await
                .map_err(|e| QueryError::Execute(e.to_string()))?;
            extract_id_column(&batches)
        };
        self.block_on(drive)
    }
}
impl Supertable {
    /// Refreshes this table, takes a reader snapshot, and registers a
    /// [`SupertableProvider`] for that snapshot in `ctx` under `name`.
    ///
    /// On success, returns the reader snapshot the provider was built from.
    ///
    /// # Errors
    ///
    /// Returns [`QueryError::Plan`] if `ctx` rejects the table registration.
    pub(crate) fn register_into(
        &self,
        ctx: &SessionContext,
        name: &str,
    ) -> Result<Arc<SupertableReader>, QueryError> {
        self.ensure_fresh();
        let snapshot = Arc::new(self.reader());
        let opts = self.options();
        let table = SupertableProvider::new(
            opts.scalar_schema(),
            Arc::clone(snapshot.manifest()),
            Arc::clone(&opts.store),
            opts.disk_cache.as_ref().map(Arc::clone),
            snapshot.tombstone_cache.clone(),
        );
        ctx.register_table(name, Arc::new(table))
            .map(|_previous| snapshot)
            .map_err(|e| QueryError::Plan(e.to_string()))
    }
}
/// Flattens the single ID column of `batches` into a vector of raw `i128`
/// values. Null entries are skipped.
///
/// # Errors
///
/// Returns [`QueryError::Plan`] if any batch does not have exactly one column,
/// or if that column is not a [`Decimal128Array`].
fn extract_id_column(batches: &[RecordBatch]) -> Result<Vec<i128>, QueryError> {
    // Every row of every batch is an upper bound on the output length, so
    // reserve once instead of growing the vector push by push.
    let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
    let mut out: Vec<i128> = Vec::with_capacity(total_rows);
    for batch in batches {
        if batch.num_columns() != 1 {
            return Err(QueryError::Plan(format!(
                "scan_ids_matching: expected 1-column batch, got {}",
                batch.num_columns()
            )));
        }
        let arr = batch
            .column(0)
            .as_any()
            .downcast_ref::<Decimal128Array>()
            .ok_or_else(|| {
                QueryError::Plan("scan_ids_matching: _id column not Decimal128".into())
            })?;
        out.extend(
            (0..arr.len())
                .filter(|&i| !arr.is_null(i))
                .map(|i| arr.value(i)),
        );
    }
    Ok(out)
}
// Tests for the SQL surface exposed through `SupertableReader::query_sql`:
// counts, filters, GROUP BY, scans over several superfiles, equality on the
// FTS-indexed `title` column, session-cache lifetime, and hiding of vector
// columns from SQL.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{
Array, Decimal128Array, FixedSizeListArray, Float32Array, Int64Array, LargeStringArray,
RecordBatch, StringArray, StringViewArray,
};
use arrow_schema::{DataType, Field, Schema};
use crate::{
superfile::{
builder::{FtsConfig, VectorConfig},
vector::{distance::Metric, rerank_codec::RerankCodec},
},
supertable::{Supertable, SupertableOptions, error::QueryError},
test_helpers::default_tokenizer as tok,
};
/// User schema with two non-nullable `LargeUtf8` columns, `category` and
/// `title`. It declares no `_id` field; the tests below expect SQL to expose
/// `_id` in addition to these columns.
fn schema_id_cat_title() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("category", DataType::LargeUtf8, false),
Field::new("title", DataType::LargeUtf8, false),
]))
}
/// Options over `schema_id_cat_title`: an FTS index on `title`, no vector
/// columns, the default test tokenizer, and a single-threaded rayon writer
/// pool.
fn options_id_cat_title() -> SupertableOptions {
let pool = Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(1)
.build()
.expect("rayon pool"),
);
SupertableOptions::new(
schema_id_cat_title(),
vec![FtsConfig {
column: "title".into(),
}],
vec![],
Some(tok()),
)
.expect("valid options")
.with_writer_pool(pool)
}
/// Builds a batch pairing `cats[i]` with `titles[i]`; the slices must have
/// equal length. `_start` is accepted but unused.
fn build_cat_batch(_start: u64, cats: &[&str], titles: &[&str]) -> RecordBatch {
assert_eq!(cats.len(), titles.len());
let cat_arr = LargeStringArray::from(cats.to_vec());
let title_arr = LargeStringArray::from(titles.to_vec());
RecordBatch::try_new(
schema_id_cat_title(),
vec![Arc::new(cat_arr), Arc::new(title_arr)],
)
.expect("build batch")
}
/// Runs `sql` on a new reader of `st` and returns the first value of the
/// first column of the first batch, which must be `Int64`. Panics if the
/// query fails or returns no batches.
fn run_count(st: &Supertable, sql: &str) -> i64 {
let batches = st.reader().query_sql(sql).expect("query_sql ok");
assert!(!batches.is_empty(), "expected at least one result batch");
let n = batches[0]
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.expect("count column is Int64");
n.value(0)
}
/// COUNT(*) on a table with no commits is 0.
#[test]
fn query_sql_count_star_returns_zero_on_empty_supertable() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let n = run_count(&st, "SELECT COUNT(*) FROM supertable");
assert_eq!(n, 0);
}
/// COUNT(*) equals the number of rows committed.
#[test]
fn query_sql_count_star_returns_total_doc_count() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "rust", "python"],
&["a", "b", "c"],
))
.expect("append");
w.commit().expect("commit");
let n = run_count(&st, "SELECT COUNT(*) FROM supertable");
assert_eq!(n, 3);
}
/// After a SQL query has populated the session cache, dropping the writer
/// and the table must release the shared inner state. The test checks this
/// by requiring that the weak handle no longer upgrades.
#[test]
fn query_sql_session_cache_does_not_leak_consumer() {
let weak = {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["rust"], &["a"]))
.expect("append");
w.commit().expect("commit");
assert_eq!(run_count(&st, "SELECT COUNT(*) FROM supertable"), 1);
let weak = Arc::downgrade(st.inner());
drop(w);
drop(st);
weak
};
assert!(
weak.upgrade().is_none(),
"SQL session cache leaked the consumer — the \
inner -> SessionContext -> TVF -> reader -> inner cycle was not broken",
);
}
/// A WHERE on the non-FTS `category` column counts only matching rows.
#[test]
fn query_sql_filter_predicate_applied_above_mem_table() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "rust", "python", "rust", "go"],
&["a", "b", "c", "d", "e"],
))
.expect("append");
w.commit().expect("commit");
let n = run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE category = 'rust'",
);
assert_eq!(n, 3);
}
/// GROUP BY category yields the expected per-group counts. The category
/// column's string array type is accepted in any of three encodings.
#[test]
fn query_sql_group_by_returns_correct_per_category_counts() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "rust", "python", "rust", "python", "go"],
&["a", "b", "c", "d", "e", "f"],
))
.expect("append");
w.commit().expect("commit");
let batches = st
.reader()
.query_sql(
"SELECT category, COUNT(*) AS n FROM supertable \
GROUP BY category ORDER BY category",
)
.expect("group-by query");
assert_eq!(batches.len(), 1);
let cat_col = batches[0].column(0);
let counts = batches[0]
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.expect("count is Int64");
// Accept LargeUtf8, Utf8 or Utf8View for the grouped string column.
let extract = |i: usize| -> String {
if let Some(a) = cat_col.as_any().downcast_ref::<LargeStringArray>() {
a.value(i).to_string()
} else if let Some(a) = cat_col.as_any().downcast_ref::<StringArray>() {
a.value(i).to_string()
} else if let Some(a) = cat_col.as_any().downcast_ref::<StringViewArray>() {
a.value(i).to_string()
} else {
panic!("unexpected category column type: {:?}", cat_col.data_type())
}
};
let mut got: Vec<(String, i64)> = (0..cat_col.len())
.map(|i| (extract(i), counts.value(i)))
.collect();
got.sort();
assert_eq!(
got,
vec![
("go".to_string(), 1),
("python".to_string(), 2),
("rust".to_string(), 3),
]
);
}
/// Three commits produce three superfiles (asserted). COUNT(*) with and
/// without a filter covers rows from all of them.
#[test]
fn query_sql_scans_across_multiple_superfiles() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["rust", "rust"], &["a", "b"]))
.expect("a1");
w.commit().expect("c1");
w.append(&build_cat_batch(10, &["python"], &["c"]))
.expect("a2");
w.commit().expect("c2");
w.append(&build_cat_batch(20, &["rust", "go"], &["d", "e"]))
.expect("a3");
w.commit().expect("c3");
assert_eq!(st.reader().n_superfiles(), 3);
let n_total = run_count(&st, "SELECT COUNT(*) FROM supertable");
assert_eq!(n_total, 5);
let n_rust = run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE category = 'rust'",
);
assert_eq!(n_rust, 3);
}
/// Equality on the FTS-indexed `title` column across three superfiles finds
/// the one matching row, and finds nothing for an absent value.
#[test]
fn query_sql_equality_on_fts_column_across_superfiles_is_correct() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["x"], &["alpha"]))
.expect("a1");
w.commit().expect("c1");
w.append(&build_cat_batch(10, &["y"], &["bravo"]))
.expect("a2");
w.commit().expect("c2");
w.append(&build_cat_batch(20, &["z"], &["charlie"]))
.expect("a3");
w.commit().expect("c3");
assert_eq!(st.reader().n_superfiles(), 3);
assert_eq!(
run_count(&st, "SELECT COUNT(*) FROM supertable WHERE title = 'bravo'"),
1
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'nonexistent'"
),
0
);
}
/// Multi-word equality on `title` matches the whole string only; a prefix of
/// the words does not match.
#[test]
fn query_sql_multiword_equality_on_fts_column_is_correct() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["lang"], &["rust async runtime"]))
.expect("a1");
w.commit().expect("c1");
w.append(&build_cat_batch(10, &["lang"], &["python data science"]))
.expect("a2");
w.commit().expect("c2");
assert_eq!(st.reader().n_superfiles(), 2);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'rust async runtime'"
),
1
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'rust async'"
),
0
);
}
/// When one title's words are a superset of another's, equality on the
/// shorter title returns only the exact match. Checked for both COUNT(*) and
/// a row projection.
#[test]
fn query_sql_fts_equality_superset_is_narrowed_to_exact_match() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["x", "y"],
&["rust async", "rust async runtime"],
))
.expect("append");
w.commit().expect("commit");
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'rust async'",
),
1,
);
let batches = st
.reader()
.query_sql("SELECT title FROM supertable WHERE title = 'rust async'")
.expect("query");
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total, 1);
}
/// OR, AND with a non-FTS column, and IN-lists over `title` all return exact
/// counts.
#[test]
fn query_sql_fts_or_and_in_are_exact() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "python", "rust", "go"],
&["alpha", "beta", "gamma", "delta"],
))
.expect("append");
w.commit().expect("commit");
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'alpha' OR title = 'beta'",
),
2,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable \
WHERE title = 'alpha' AND category = 'rust'",
),
1,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable \
WHERE title = 'alpha' AND category = 'python'",
),
0,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title IN ('alpha', 'delta', 'zzz')",
),
2,
);
}
/// NOT, `!=`, and `!=` combined with `title` equality return exact counts.
/// Two rows share the title `alpha`.
#[test]
fn query_sql_not_predicates_are_exact() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "python", "rust", "go"],
&["alpha", "beta", "alpha", "delta"],
))
.expect("append");
w.commit().expect("commit");
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE NOT (title = 'alpha')",
),
2,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title != 'alpha'"
),
2,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable \
WHERE title = 'alpha' AND category != 'rust'",
),
0,
);
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable \
WHERE title = 'alpha' AND category != 'python'",
),
2,
);
}
/// An OR whose second branch is on the non-FTS `category` column counts the
/// union of both branches.
#[test]
fn query_sql_or_with_non_fts_branch_matches_full_scan() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(
0,
&["rust", "python", "go", "go"],
&["alpha", "beta", "gamma", "delta"],
))
.expect("append");
w.commit().expect("commit");
assert_eq!(
run_count(
&st,
"SELECT COUNT(*) FROM supertable WHERE title = 'alpha' OR category = 'go'",
),
3,
);
}
/// `_id` is exposed as `Decimal128`. Ordering by it across two commits
/// yields three strictly increasing values.
#[test]
fn query_sql_select_orders_ids_across_superfiles() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(100, &["a", "b"], &["t1", "t2"]))
.expect("a1");
w.commit().expect("c1");
w.append(&build_cat_batch(200, &["c"], &["t3"]))
.expect("a2");
w.commit().expect("c2");
let batches = st
.reader()
.query_sql("SELECT _id FROM supertable ORDER BY _id")
.expect("query");
let ids: Vec<i128> = batches
.iter()
.flat_map(|b| {
let a = b
.column(0)
.as_any()
.downcast_ref::<Decimal128Array>()
.expect("_id is Decimal128");
(0..a.len()).map(|i| a.value(i)).collect::<Vec<_>>()
})
.collect();
assert_eq!(ids.len(), 3);
for w in ids.windows(2) {
assert!(w[0] < w[1], "expected strictly increasing _id");
}
}
/// `SELECT *` returns exactly `_id` followed by the user columns.
#[test]
fn query_sql_select_star_exposes_only_user_columns_plus_id() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["x"], &["t"])).expect("a");
w.commit().expect("c");
let batches = st
.reader()
.query_sql("SELECT * FROM supertable LIMIT 1")
.expect("query");
let schema = batches[0].schema();
let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(names, vec!["_id", "category", "title"]);
}
/// Repeated queries against the same snapshot keep returning the correct
/// count. This exercises the session-cache hit path.
#[test]
fn query_sql_runtime_is_cached_across_calls() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_cat_batch(0, &["x"], &["t"])).expect("a");
w.commit().expect("c");
for _ in 0..3 {
let n = run_count(&st, "SELECT COUNT(*) FROM supertable");
assert_eq!(n, 1);
}
}
/// An unknown function name surfaces as `QueryError::Plan`.
#[test]
fn query_sql_invalid_sql_returns_plan_error() {
let st = Supertable::create(options_id_cat_title()).expect("create");
let err = st
.reader()
.query_sql("SELECT NOT_A_REAL_FN(*) FROM supertable")
.expect_err("expected a plan error");
assert!(
matches!(err, QueryError::Plan(_)),
"expected Plan variant; got {err:?}"
);
}
/// Schema with a `title` text column and an `emb` column. `emb` is a
/// non-nullable `FixedSizeList<Float32>` of length `dim`.
fn schema_with_vector(dim: usize) -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("title", DataType::LargeUtf8, false),
Field::new(
"emb",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dim as i32,
),
false,
),
]))
}
/// Options with FTS on `title` and a cosine vector index on `emb`. The index
/// uses 4 centroids, rotation seed 0 and fp32 rerank codec. A
/// single-threaded writer pool is attached.
fn options_with_vector(dim: usize) -> SupertableOptions {
let pool = Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(1)
.build()
.expect("rayon pool"),
);
SupertableOptions::new(
schema_with_vector(dim),
vec![FtsConfig {
column: "title".into(),
}],
vec![VectorConfig {
column: "emb".into(),
dim,
n_cent: 4,
rot_seed: 0,
metric: Metric::Cosine,
rerank_codec: RerankCodec::Fp32,
}],
Some(tok()),
)
.expect("valid options")
.with_writer_pool(pool)
}
/// Builds `n` rows titled `doc {i}`, with deterministic embeddings where
/// component `d` of row `i` is `(i + d) / 100`. `_start` is accepted but
/// unused.
fn build_vector_batch(_start: u64, n: usize, dim: usize) -> RecordBatch {
let titles = LargeStringArray::from((0..n).map(|i| format!("doc {i}")).collect::<Vec<_>>());
let mut flat = Vec::<f32>::with_capacity(n * dim);
for i in 0..n {
for d in 0..dim {
flat.push(((i + d) as f32) / 100.0);
}
}
let item_field = Arc::new(Field::new("item", DataType::Float32, true));
let values = Float32Array::from(flat);
let emb = FixedSizeListArray::try_new(
item_field,
dim as i32,
Arc::new(values) as Arc<dyn Array>,
None,
)
.expect("FixedSizeList build");
RecordBatch::try_new(
schema_with_vector(dim),
vec![Arc::new(titles), Arc::new(emb)],
)
.expect("build batch")
}
/// Vector columns are absent from `SELECT *`; only `_id` and `title` appear.
#[test]
fn query_sql_hides_vector_columns_from_sql_surface() {
let st = Supertable::create(options_with_vector(16)).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_vector_batch(0, 8, 16)).expect("append");
w.commit().expect("commit");
let batches = st
.reader()
.query_sql("SELECT * FROM supertable LIMIT 1")
.expect("query");
let schema = batches[0].schema();
let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(names, vec!["_id", "title"]);
}
/// Naming a vector column in SQL fails at planning time with
/// `QueryError::Plan`.
#[test]
fn query_sql_referencing_vector_column_returns_plan_error() {
let st = Supertable::create(options_with_vector(16)).expect("create");
let mut w = st.writer().expect("writer");
w.append(&build_vector_batch(0, 8, 16)).expect("append");
w.commit().expect("commit");
let err = st
.reader()
.query_sql("SELECT emb FROM supertable")
.expect_err("vector column should not be in the SQL schema");
assert!(
matches!(err, QueryError::Plan(_)),
"expected Plan variant; got {err:?}"
);
}
}