#![deny(clippy::unwrap_used)]
use std::sync::Arc;
use arrow_array::{
Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, LargeStringArray, RecordBatch,
};
use arrow_schema::{DataType, Field, Schema};
use infino::{
superfile::builder::FtsConfig,
supertable::{Supertable, SupertableOptions},
test_helpers::{default_tokenizer, default_vector_config},
};
/// Width of the `emb` fixed-size-list vector column (and of every one-hot
/// embedding / query vector built in this file).
const DIM: usize = 16;
/// Seed passed to `default_vector_config` for the `emb` column.
/// NOTE(review): name suggests it seeds a vector rotation — confirm in `test_helpers`.
const VECTOR_ROT_SEED: u64 = 11;
/// Number of rows appended before each `commit()` in `demo_table`; the demo
/// table is built from two such commits, and the second batch's ratings start here.
const DOCS_PER_COMMIT: usize = 8;
/// Top-k passed to the search TVFs in the "surface" tests. It is larger than the
/// 16-row demo table, presumably so top-k truncation never hides a matching row.
const SURFACE_TOP_K: usize = 32;
/// Returns the Arrow type for a fixed-size list of `dim` nullable `Float32`
/// items. This is the type declared for the `emb` column in this file.
///
/// # Panics
/// Panics if `dim` does not fit in an `i32` (Arrow's fixed-size-list width type).
/// A plain `as` cast would silently wrap to a bogus (possibly negative) width.
fn fixed_list_f32(dim: usize) -> DataType {
    let width = i32::try_from(dim).expect("vector dimension must fit in i32");
    DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), width)
}
/// Builds the `SupertableOptions` used by every test.
///
/// The schema has three non-nullable columns: `title` (`LargeUtf8`),
/// `rating` (`Int64`) and `emb` (`DIM`-wide `Float32` fixed-size list).
/// It attaches an `FtsConfig` on `title`, the default vector config for `emb`,
/// and the default tokenizer.
///
/// # Panics
/// Panics if `SupertableOptions::new` rejects the configuration.
fn options_title_rating_emb() -> SupertableOptions {
    let columns = vec![
        Field::new("title", DataType::LargeUtf8, false),
        Field::new("rating", DataType::Int64, false),
        Field::new("emb", fixed_list_f32(DIM), false),
    ];
    let table_schema = Arc::new(Schema::new(columns));
    let text_indexes = vec![FtsConfig {
        column: "title".into(),
    }];
    let vector_indexes = vec![default_vector_config("emb", VECTOR_ROT_SEED)];
    let tokenizer = Some(default_tokenizer());

    SupertableOptions::new(table_schema, text_indexes, vector_indexes, tokenizer)
        .expect("valid options")
}
/// Builds a three-column batch (`title`, `rating`, `emb`, in that order) for `schema`.
///
/// Row `i` receives `rating = base + i` and a one-hot `emb` vector whose hot
/// component is `(base + i) % DIM`. As a result, consecutive batches built with
/// consecutive `base` values carry distinct ratings.
///
/// # Panics
/// Panics if the fixed-size-list or the record batch cannot be assembled
/// (e.g. `schema` does not match the three columns above).
fn build_batch(titles: &[&str], base: usize, schema: Arc<Schema>) -> RecordBatch {
    let row_ids = base..base + titles.len();

    let rating_values: Vec<i64> = row_ids.clone().map(|id| id as i64).collect();

    // Row-major flattening: DIM consecutive floats per row, a single 1.0 each.
    let embedding_values: Vec<f32> = row_ids
        .flat_map(|id| {
            let hot = id % DIM;
            (0..DIM).map(move |d| if d == hot { 1.0 } else { 0.0 })
        })
        .collect();

    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
    let child: ArrayRef = Arc::new(Float32Array::from(embedding_values));
    let embeddings =
        FixedSizeListArray::try_new(item_field, DIM as i32, child, None).expect("FSL");

    let columns: Vec<ArrayRef> = vec![
        Arc::new(LargeStringArray::from(titles.to_vec())),
        Arc::new(Int64Array::from(rating_values)),
        Arc::new(embeddings),
    ];
    RecordBatch::try_new(schema, columns).expect("batch")
}
/// Creates a fresh supertable populated by two append+commit rounds of
/// `DOCS_PER_COMMIT` rows each. The first round has ratings `0..8` and the
/// second `8..16`, so ratings are unique across the table.
///
/// # Panics
/// Panics if creating the table, opening the writer, or any append/commit fails.
fn demo_table() -> Supertable {
    let table = Supertable::create(options_title_rating_emb()).expect("create");
    let schema = table.options().schema.clone();

    let first_titles = [
        "rust async",
        "python data",
        "java spring",
        "go rust",
        "ruby rails",
        "scala akka",
        "kotlin flow",
        "rust systems",
    ];
    let second_titles = [
        "swift ui",
        "rust web",
        "elixir otp",
        "haskell lazy",
        "rust embedded",
        "perl regex",
        "lua script",
        "rust async runtime",
    ];

    // Scope the writer so it is released before the table is returned.
    {
        let mut writer = table.writer().expect("writer");
        let rounds = [
            (&first_titles, 0, "append seg1", "commit seg1"),
            (&second_titles, DOCS_PER_COMMIT, "append seg2", "commit seg2"),
        ];
        for (titles, base, append_msg, commit_msg) in rounds {
            writer
                .append(&build_batch(titles, base, schema.clone()))
                .expect(append_msg);
            writer.commit().expect(commit_msg);
        }
    }

    table
}
/// Renders a `DIM`-wide one-hot vector as comma-separated `0`/`1` digits.
///
/// Position `active` is `1` and every other position is `0`. If `active >= DIM`,
/// every position is `0`. The tests interpolate this string as the query-vector
/// literal of `vector_search` / `hybrid_search`.
fn csv_one_hot(active: usize) -> String {
    let mut cells = vec!["0"; DIM];
    if let Some(hot) = cells.get_mut(active) {
        *hot = "1";
    }
    cells.join(",")
}
fn explain_text(st: &Supertable, sql: &str) -> String {
let batches = st
.reader()
.query_sql(&format!("EXPLAIN {sql}"))
.expect("explain");
let mut out = String::new();
for batch in &batches {
for column in batch.columns() {
if let Some(strings) = column.as_any().downcast_ref::<arrow_array::StringArray>() {
for i in 0..strings.len() {
if !strings.is_null(i) {
out.push_str(strings.value(i));
out.push('\n');
}
}
}
}
}
out
}
fn row_count(batches: &[RecordBatch]) -> usize {
batches.iter().map(RecordBatch::num_rows).sum()
}
/// `EXPLAIN` must produce non-empty plan text for every search TVF,
/// including `token_match` in its explicit `'and'` mode.
#[test]
fn explain_formats_every_search_tvf_plan() {
    let table = demo_table();
    let query_vector = csv_one_hot(0);

    let tvf_queries: Vec<String> = vec![
        String::from("SELECT _id, score FROM bm25_search('title', 'rust', 8)"),
        String::from("SELECT _id, score FROM bm25_search_prefix('title', 'rus', 8)"),
        format!("SELECT _id, score FROM vector_search('emb', '{query_vector}', 8)"),
        String::from("SELECT _id FROM token_match('title', 'rust')"),
        String::from("SELECT _id FROM token_match('title', 'rust systems', 'and')"),
        String::from("SELECT _id FROM exact_match('title', 'rust async')"),
        format!(
            "SELECT _id, score FROM hybrid_search('title', 'rust', 'emb', '{query_vector}', 8)"
        ),
    ];

    for query in tvf_queries.iter() {
        let plan = explain_text(&table, query);
        let has_text = !plan.trim().is_empty();
        assert!(has_text, "EXPLAIN produced no plan text for: {query}");
    }
}
/// `EXPLAIN ANALYZE` executes the BM25, vector and hybrid TVFs and returns
/// non-empty output for each.
#[test]
fn explain_analyze_runs_and_formats_metrics() {
    let table = demo_table();
    let query_vector = csv_one_hot(0);

    let analyzed = [
        String::from("SELECT _id FROM bm25_search('title', 'rust', 8)"),
        format!("SELECT _id FROM vector_search('emb', '{query_vector}', 8)"),
        format!("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{query_vector}', 8)"),
    ];

    analyzed.iter().for_each(|q| {
        // `explain_text` prepends "EXPLAIN ", yielding "EXPLAIN ANALYZE <q>".
        let plan = explain_text(&table, &format!("ANALYZE {q}"));
        assert!(!plan.trim().is_empty(), "EXPLAIN ANALYZE empty for: {q}");
    });
}
/// `SELECT *` over every search TVF must expose `_id` first, materialize the
/// scalar columns (`title`, `rating`) plus `score`, and never expose the
/// vector column `emb`.
#[test]
fn star_projection_materializes_scalar_columns_for_every_tvf() {
    let st = demo_table();
    let qv = csv_one_hot(0);
    let queries = [
        format!("SELECT * FROM bm25_search('title', 'rust', {SURFACE_TOP_K})"),
        format!("SELECT * FROM bm25_search_prefix('title', 'rus', {SURFACE_TOP_K})"),
        format!("SELECT * FROM vector_search('emb', '{qv}', {SURFACE_TOP_K})"),
        "SELECT * FROM token_match('title', 'rust')".to_string(),
        "SELECT * FROM exact_match('title', 'rust async')".to_string(),
        format!("SELECT * FROM hybrid_search('title', 'rust', 'emb', '{qv}', {SURFACE_TOP_K})"),
    ];
    for q in &queries {
        let batches = st
            .reader()
            .query_sql(q)
            .unwrap_or_else(|e| panic!("star query failed for {q}: {e:?}"));
        // `batches[0]` would panic with a bare index error; name the query instead.
        let schema = batches
            .first()
            .unwrap_or_else(|| panic!("star query returned no batches: {q}"))
            .schema();
        assert_eq!(schema.field(0).name(), "_id", "{q}");
        for column in ["title", "rating", "score"] {
            assert!(
                schema.index_of(column).is_ok(),
                "missing column `{column}`: {q}"
            );
        }
        assert!(
            schema.index_of("emb").is_err(),
            "vector column must never reach SQL: {q}"
        );
    }
}
/// `ORDER BY ... LIMIT` applied on top of a BM25 TVF scan returns between one
/// and three rows.
#[test]
fn order_by_and_limit_wrap_the_tvf_scan() {
    let table = demo_table();
    let sql = format!(
        "SELECT _id, score FROM bm25_search('title', 'rust', {SURFACE_TOP_K}) \
         ORDER BY score DESC LIMIT 3"
    );
    let rows = row_count(&table.reader().query_sql(&sql).expect("order+limit query"));

    assert!(rows <= 3, "LIMIT 3 must cap the row count");
    assert!(rows >= 1, "'rust' should match something");
}
/// A `WHERE` predicate on the materialized scalar `rating` is applied above the
/// BM25 TVF scan. At least one row must survive, and every surviving row must
/// satisfy the predicate.
#[test]
fn filter_on_materialized_scalar_above_tvf() {
    let st = demo_table();
    let batches = st
        .reader()
        .query_sql(&format!(
            "SELECT _id, rating FROM bm25_search('title', 'rust', {SURFACE_TOP_K}) \
             WHERE rating >= 8"
        ))
        .expect("filter query");

    // The second commit in `demo_table` (ratings 8..=15) contains "rust web",
    // "rust embedded" and "rust async runtime". An empty result means the scan
    // or the filter is broken; it must not pass vacuously.
    assert!(
        row_count(&batches) >= 1,
        "WHERE rating >= 8 should keep the second-commit 'rust' rows"
    );

    for b in &batches {
        let idx = b.schema().index_of("rating").expect("rating col");
        let ratings = b
            .column(idx)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("i64 rating");
        assert!(
            ratings.values().iter().all(|&r| r >= 8),
            "WHERE rating >= 8 leaked a row"
        );
    }
}
/// Plain scans of the `supertable` base table support `COUNT(*)`, `SUM`, and an
/// ordered, limited projection.
#[test]
fn base_table_scan_supports_aggregates_and_projection() {
    /// Runs `sql`, which must yield an `Int64` first column with at least one
    /// row, and returns its first value.
    fn single_i64(st: &Supertable, sql: &str) -> i64 {
        let batches = st
            .reader()
            .query_sql(sql)
            .unwrap_or_else(|e| panic!("{sql}: {e:?}"));
        batches
            .first()
            .unwrap_or_else(|| panic!("{sql}: no batches returned"))
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap_or_else(|| panic!("{sql}: first column is not Int64"))
            .value(0)
    }

    let st = demo_table();
    let total_rows = DOCS_PER_COMMIT * 2;

    let n = single_i64(&st, "SELECT COUNT(*) AS n FROM supertable");
    assert_eq!(n, total_rows as i64, "all rows visible");

    let s = single_i64(&st, "SELECT SUM(rating) AS s FROM supertable");
    let expected: i64 = (0..total_rows as i64).sum();
    assert_eq!(s, expected, "SUM(rating) over the whole table");

    let top = st
        .reader()
        .query_sql("SELECT title, rating FROM supertable ORDER BY rating DESC LIMIT 2")
        .expect("ordered projection");
    let top_ratings: Vec<i64> = top
        .iter()
        .flat_map(|b| {
            let idx = b.schema().index_of("rating").expect("rating col");
            b.column(idx)
                .as_any()
                .downcast_ref::<Int64Array>()
                .expect("i64 rating")
                .values()
                .to_vec()
        })
        .collect();
    // Ratings are exactly 0..total_rows (unique), so the top two are fixed.
    // Asserting only `<= 2` would also accept zero rows or a wrong order.
    let max = total_rows as i64 - 1;
    assert_eq!(
        top_ratings,
        [max, max - 1],
        "ORDER BY rating DESC LIMIT 2 must return the two highest ratings"
    );
}
/// `EXPLAIN` over a filtered base-table scan returns non-empty plan text.
#[test]
fn explain_base_table_scan() {
    let table = demo_table();
    let sql = "SELECT title FROM supertable WHERE rating > 4";
    let is_blank = explain_text(&table, sql).trim().is_empty();
    assert!(!is_blank, "base-table EXPLAIN was empty");
}