#![deny(clippy::unwrap_used)]
use std::sync::Arc;
use arrow_array::{
ArrayRef, FixedSizeListArray, Float32Array, Int64Array, LargeStringArray, RecordBatch,
};
use arrow_schema::{DataType, Field, Schema};
use infino::{
superfile::builder::FtsConfig,
supertable::{Supertable, SupertableOptions},
test_helpers::{default_tokenizer, default_vector_config},
};
/// Dimensionality of the `emb` vector column; every one-hot embedding and
/// query vector built in this file has exactly this many components.
const DIM: usize = 16;
/// Seed passed to `default_vector_config` for the `emb` column.
/// NOTE(review): the name suggests a rotation seed for the vector index — confirm.
const VECTOR_ROT_SEED: u64 = 13;
/// Size of each fixture segment in `demo_table`; also the `base` row offset
/// used when building the second segment.
const DOCS_PER_COMMIT: usize = 8;
/// `k` for ranked-TVF queries that are expected to surface results; chosen
/// larger than the fixture's total row count (two segments of eight rows).
const SURFACE_TOP_K: usize = 32;
/// Arrow `FixedSizeList` type holding `dim` nullable `Float32` items named `"item"`.
fn fixed_list_f32(dim: usize) -> DataType {
    let item = Field::new("item", DataType::Float32, true);
    let width = dim as i32;
    DataType::FixedSizeList(Arc::new(item), width)
}
/// Options for the fixture table: a `title` string column with full-text
/// indexing, an `Int64` `rating`, and a `DIM`-wide `emb` vector column.
///
/// # Panics
/// Panics if `SupertableOptions::new` rejects the configuration.
fn options_title_rating_emb() -> SupertableOptions {
    let fields = vec![
        Field::new("title", DataType::LargeUtf8, false),
        Field::new("rating", DataType::Int64, false),
        Field::new("emb", fixed_list_f32(DIM), false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let fts_columns = vec![FtsConfig {
        column: "title".into(),
    }];
    let vector_columns = vec![default_vector_config("emb", VECTOR_ROT_SEED)];
    let tokenizer = Some(default_tokenizer());
    SupertableOptions::new(schema, fts_columns, vector_columns, tokenizer).expect("valid options")
}
/// Builds one `RecordBatch` for `schema` (title, rating, emb) from `titles`.
///
/// Row `i` gets `rating = base + i` and an `emb` one-hot vector of length
/// `DIM` whose single `1.0` sits at index `(base + i) % DIM`.
///
/// # Panics
/// Panics if the fixed-size list array or the record batch cannot be built,
/// e.g. when `schema` disagrees with the three columns produced here.
fn build_batch(titles: &[&str], base: usize, schema: Arc<Schema>) -> RecordBatch {
    let doc_ids = (0..titles.len()).map(move |offset| base + offset);

    let ratings: Vec<i64> = doc_ids.clone().map(|doc| doc as i64).collect();

    let one_hot: Vec<f32> = doc_ids
        .flat_map(|doc| {
            let hot = doc % DIM;
            (0..DIM).map(move |d| if d == hot { 1.0_f32 } else { 0.0_f32 })
        })
        .collect();

    let item_field = Arc::new(Field::new("item", DataType::Float32, true));
    let values: ArrayRef = Arc::new(Float32Array::from(one_hot));
    let embeddings =
        FixedSizeListArray::try_new(item_field, DIM as i32, values, None).expect("FSL");

    let columns: Vec<ArrayRef> = vec![
        Arc::new(LargeStringArray::from(titles.to_vec())),
        Arc::new(Int64Array::from(ratings)),
        Arc::new(embeddings),
    ];
    RecordBatch::try_new(schema, columns).expect("batch")
}
/// Builds the shared fixture table from two committed segments of
/// `DOCS_PER_COMMIT` titles each.
///
/// The second segment is built with `base = DOCS_PER_COMMIT`, so ratings and
/// one-hot embedding positions continue where the first segment stopped (see
/// `build_batch`). The segment arrays are typed `[&str; DOCS_PER_COMMIT]`,
/// which makes the compiler enforce that offset. Adding or removing a title
/// is a compile error rather than silently overlapping ratings/vectors.
///
/// # Panics
/// Panics if table creation, writer acquisition, append, or commit fails.
fn demo_table() -> Supertable {
    const SEG1_TITLES: [&str; DOCS_PER_COMMIT] = [
        "rust async",
        "python data",
        "java spring",
        "go rust",
        "ruby rails",
        "scala akka",
        "kotlin flow",
        "rust systems",
    ];
    const SEG2_TITLES: [&str; DOCS_PER_COMMIT] = [
        "swift ui",
        "rust web",
        "elixir otp",
        "haskell lazy",
        "rust embedded",
        "perl regex",
        "lua script",
        "rust async runtime",
    ];

    let st = Supertable::create(options_title_rating_emb()).expect("create");
    let schema = st.options().schema.clone();
    let mut w = st.writer().expect("writer");

    w.append(&build_batch(&SEG1_TITLES, 0, schema.clone()))
        .expect("append seg1");
    w.commit().expect("commit seg1");

    // Offset equals SEG1_TITLES.len(), guaranteed by the array types above.
    w.append(&build_batch(&SEG2_TITLES, DOCS_PER_COMMIT, schema))
        .expect("append seg2");
    w.commit().expect("commit seg2");

    // Release the writer before handing the table back.
    // NOTE(review): presumably required because the writer borrows `st` — confirm.
    drop(w);
    st
}
fn csv_one_hot(active: usize) -> String {
(0..DIM)
.map(|d| if d == active { "1" } else { "0" })
.collect::<Vec<_>>()
.join(",")
}
fn row_count(batches: &[RecordBatch]) -> usize {
batches.iter().map(RecordBatch::num_rows).sum()
}
/// Runs every query in `queries` against a fresh reader of `st` and panics,
/// tagged with `label` and the offending SQL, on the first one that succeeds.
fn assert_all_error(st: &Supertable, label: &str, queries: &[String]) {
    queries.iter().for_each(|q| {
        let succeeded = st.reader().query_sql(q).is_ok();
        assert!(
            !succeeded,
            "[{label}] expected an error, query unexpectedly succeeded: {q}"
        );
    });
}
/// `bm25_search` called with one, two, or five arguments must be rejected.
#[test]
fn bm25_search_wrong_arity_errors() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM bm25_search('title')",
        "SELECT _id FROM bm25_search('title', 'rust')",
        "SELECT _id FROM bm25_search('title', 'rust', 8, 'and', 'extra')",
    ]
    .map(String::from);
    assert_all_error(&table, "bm25_search arity", &queries);
}
/// `bm25_search` must reject each of these:
/// - a numeric column argument;
/// - a non-integer `k`;
/// - a fourth argument (`'maybe'`) that is not an accepted option.
#[test]
fn bm25_search_wrong_arg_types_error() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM bm25_search(42, 'rust', 8)",
        "SELECT _id FROM bm25_search('title', 'rust', 'eight')",
        "SELECT _id FROM bm25_search('title', 'rust', 8, 'maybe')",
    ]
    .map(String::from);
    assert_all_error(&table, "bm25_search types", &queries);
}
/// `bm25_search` over a column that is not in the schema must fail.
#[test]
fn bm25_search_unknown_column_errors() {
    let table = demo_table();
    let queries = [String::from(
        "SELECT _id FROM bm25_search('nonexistent', 'rust', 8)",
    )];
    assert_all_error(&table, "bm25_search unknown column", &queries);
}
/// `bm25_search_prefix` with two or four arguments must be rejected. Unlike
/// `bm25_search`, a fourth argument is expected to be an error here.
#[test]
fn bm25_search_prefix_wrong_arity_errors() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM bm25_search_prefix('title', 'rus')",
        "SELECT _id FROM bm25_search_prefix('title', 'rus', 8, 'and')",
    ]
    .map(String::from);
    assert_all_error(&table, "bm25_search_prefix arity", &queries);
}
/// `bm25_search_prefix` must reject a numeric prefix and a non-integer `k`.
#[test]
fn bm25_search_prefix_wrong_arg_types_error() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM bm25_search_prefix('title', 7, 8)",
        "SELECT _id FROM bm25_search_prefix('title', 'rus', 'big')",
    ]
    .map(String::from);
    assert_all_error(&table, "bm25_search_prefix types", &queries);
}
/// `vector_search` with two or four arguments must be rejected.
#[test]
fn vector_search_wrong_arity_errors() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [
        format!("SELECT _id FROM vector_search('emb', '{probe}')"),
        format!("SELECT _id FROM vector_search('emb', '{probe}', 8, 'extra')"),
    ];
    assert_all_error(&table, "vector_search arity", &queries);
}
/// `vector_search` must reject each of these:
/// - a numeric column argument;
/// - a vector string with a non-numeric component;
/// - an empty vector string;
/// - a non-integer `k`.
#[test]
fn vector_search_wrong_arg_types_error() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [
        format!("SELECT _id FROM vector_search(1, '{probe}', 8)"),
        String::from("SELECT _id FROM vector_search('emb', '1,two,3', 8)"),
        String::from("SELECT _id FROM vector_search('emb', '', 8)"),
        format!("SELECT _id FROM vector_search('emb', '{probe}', 'k')"),
    ];
    assert_all_error(&table, "vector_search types", &queries);
}
/// `vector_search` over a column that is not in the schema must fail.
#[test]
fn vector_search_unknown_column_errors() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [format!(
        "SELECT _id FROM vector_search('missing', '{probe}', 8)"
    )];
    assert_all_error(&table, "vector_search unknown column", &queries);
}
/// `token_match` (one or four arguments) and `exact_match` (one or three
/// arguments) must be rejected.
#[test]
fn match_tvfs_wrong_arity_error() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM token_match('title')",
        "SELECT _id FROM token_match('title', 'rust', 'and', 'extra')",
        "SELECT _id FROM exact_match('title')",
        "SELECT _id FROM exact_match('title', 'rust async', 'extra')",
    ]
    .map(String::from);
    assert_all_error(&table, "match arity", &queries);
}
/// Each of these match calls must be rejected:
/// - `token_match` with a numeric column;
/// - `token_match` with a third argument (`'nope'`) that is not an accepted option;
/// - `exact_match` with a numeric needle.
#[test]
fn match_tvfs_wrong_arg_types_error() {
    let table = demo_table();
    let queries = [
        "SELECT _id FROM token_match(5, 'rust')",
        "SELECT _id FROM token_match('title', 'rust', 'nope')",
        "SELECT _id FROM exact_match('title', 99)",
    ]
    .map(String::from);
    assert_all_error(&table, "match types", &queries);
}
/// `hybrid_search` with four or six arguments must be rejected.
#[test]
fn hybrid_search_wrong_arity_errors() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [
        format!("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{probe}')"),
        format!("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{probe}', 8, 'extra')"),
    ];
    assert_all_error(&table, "hybrid_search arity", &queries);
}
/// `hybrid_search` must reject each of these:
/// - a numeric text column;
/// - a numeric vector column;
/// - a non-numeric vector string;
/// - a non-integer `k`.
#[test]
fn hybrid_search_wrong_arg_types_error() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [
        format!("SELECT _id FROM hybrid_search(0, 'rust', 'emb', '{probe}', 8)"),
        format!("SELECT _id FROM hybrid_search('title', 'rust', 1, '{probe}', 8)"),
        String::from("SELECT _id FROM hybrid_search('title', 'rust', 'emb', 'x,y', 8)"),
        format!("SELECT _id FROM hybrid_search('title', 'rust', 'emb', '{probe}', 'k')"),
    ];
    assert_all_error(&table, "hybrid_search types", &queries);
}
/// Ranked TVFs (`bm25_search`, `vector_search`) accept id-only, mixed, and
/// `*` projections. A `*` projection must never expose the raw `emb` column.
///
/// Each result's first batch is fetched with `.first()`, so a query that
/// returns no batches fails with a message naming that query. The original
/// `b[0]` indexing failed with an anonymous index-out-of-bounds panic.
#[test]
fn varied_projections_succeed_for_ranked_tvfs() {
    let st = demo_table();
    let qv = csv_one_hot(0);
    for q in [
        format!("SELECT _id FROM bm25_search('title', 'rust', {SURFACE_TOP_K})"),
        format!("SELECT _id FROM vector_search('emb', '{qv}', {SURFACE_TOP_K})"),
    ] {
        let b = st.reader().query_sql(&q).expect("id-only query");
        let first = b
            .first()
            .unwrap_or_else(|| panic!("id-only query returned no batches: {q}"));
        // Check the width before `field(0)`, which would panic on a zero-column schema.
        assert_eq!(
            first.num_columns(),
            1,
            "id-only must project one column: {q}"
        );
        assert_eq!(first.schema().field(0).name(), "_id", "{q}");
    }

    let mixed = st
        .reader()
        .query_sql(&format!(
            "SELECT rating, score FROM bm25_search('title', 'rust', {SURFACE_TOP_K})"
        ))
        .expect("mixed projection");
    let mixed_schema = mixed
        .first()
        .expect("mixed projection returned no batches")
        .schema();
    for col in ["rating", "score"] {
        assert!(
            mixed_schema.index_of(col).is_ok(),
            "mixed projection is missing `{col}`"
        );
    }

    let star = st
        .reader()
        .query_sql(&format!(
            "SELECT * FROM vector_search('emb', '{qv}', {SURFACE_TOP_K})"
        ))
        .expect("star projection");
    let star_schema = star
        .first()
        .expect("star projection returned no batches")
        .schema();
    for col in ["title", "rating", "score"] {
        assert!(
            star_schema.index_of(col).is_ok(),
            "star projection is missing `{col}`"
        );
    }
    assert!(
        star_schema.index_of("emb").is_err(),
        "vector column must never reach SQL"
    );
}
/// With `k = 0`, the bm25, vector, and hybrid TVFs all succeed and return zero rows.
#[test]
fn zero_k_yields_empty_result() {
    let table = demo_table();
    let probe = csv_one_hot(0);
    let queries = [
        String::from("SELECT _id, score FROM bm25_search('title', 'rust', 0)"),
        format!("SELECT _id, score FROM vector_search('emb', '{probe}', 0)"),
        format!("SELECT * FROM hybrid_search('title', 'rust', 'emb', '{probe}', 0)"),
    ];
    for query in &queries {
        let batches = table.reader().query_sql(query).expect("k=0 query");
        assert_eq!(row_count(&batches), 0, "k=0 must produce no rows: {query}");
    }
}
/// Queries whose terms match nothing in the fixture succeed and return zero rows.
#[test]
fn empty_result_queries_return_no_rows() {
    let table = demo_table();
    let queries = [
        format!("SELECT _id FROM bm25_search('title', 'zzzznomatch', {SURFACE_TOP_K})"),
        String::from("SELECT _id FROM token_match('title', 'zzzznomatch')"),
        String::from("SELECT _id FROM exact_match('title', 'no such exact title')"),
    ];
    for query in &queries {
        let batches = table.reader().query_sql(query).expect("empty-result query");
        assert_eq!(row_count(&batches), 0, "expected no matches: {query}");
    }
}
/// Malformed SQL against the base `supertable` relation must error. Cases:
/// - a truncated `WHERE`;
/// - a misspelled keyword;
/// - an unknown column;
/// - an unknown table.
#[test]
fn malformed_base_table_sql_errors() {
    let table = demo_table();
    let queries = [
        "SELECT * FROM supertable WHERE",
        "SELCT * FROM supertable",
        "SELECT no_such_column FROM supertable",
        "SELECT * FROM no_such_table",
    ]
    .map(String::from);
    assert_all_error(&table, "base-table malformed", &queries);
}