use std::sync::Arc;
use anyhow::{Context, Result, anyhow};
use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, ListBuilder};
use arrow_array::{Array, Float32Array, RecordBatch, UInt64Array};
use futures::TryStreamExt;
use lancedb::DistanceType;
use lancedb::query::{ExecutableQuery, QueryBase};
use tempfile::TempDir;
/// Width of every token embedding. `build_batch` passes this value to
/// `FixedSizeListBuilder::new` as the fixed list size of the stored vectors.
const DIM: i32 = 3;

/// One document of the fixture corpus: an id plus its token embeddings.
#[derive(Debug, Clone, PartialEq)]
struct Doc {
    /// Identifier written to the `id` column of the record batch.
    id: u64,
    /// Token embeddings of the document. The array length is derived from
    /// [`DIM`] so the in-memory shape cannot drift from the stored list size.
    tokens: Vec<[f32; DIM as usize]>,
}
#[tokio::main]
async fn main() -> Result<()> {
let e0 = [1.0_f32, 0.0, 0.0];
let e1 = [0.0_f32, 1.0, 0.0];
let e2 = [0.0_f32, 0.0, 1.0];
let query: Vec<[f32; 3]> = vec![e0, e1];
let docs = vec![
Doc {
id: 1,
tokens: vec![e0, e1],
}, Doc {
id: 2,
tokens: vec![e0, e2, e1],
}, Doc {
id: 3,
tokens: vec![e0],
}, Doc {
id: 4,
tokens: vec![e2],
}, ];
println!("Expected MaxSim (dot): id1=2.0, id2=2.0, id3=1.0, id4=0.0");
println!("Expected ranking: [{{1,2}} (tie)] > 3 > 4\n");
let tmp = TempDir::new().context("create temp dir")?;
let uri = tmp.path().to_str().context("temp path is not utf-8")?;
let conn = lancedb::connect(uri).execute().await.context("connect")?;
let batch = build_batch(&docs)?;
let table = conn
.create_table("docs", vec![batch])
.execute()
.await
.context("create_table")?;
println!(
"Stored {} docs in a List<FixedSizeList<Float32,{DIM}>> column.\n",
docs.len()
);
for (metric, assert_maxsim) in [
(DistanceType::Dot, true),
(DistanceType::Cosine, true),
(DistanceType::L2, false),
] {
let ranked = run_maxsim_query(&table, &query, metric)
.await
.with_context(|| format!("multivector query with {metric:?}"))?;
let order: Vec<u64> = ranked.iter().map(|(id, _)| *id).collect();
println!("{metric:?}: returned order (best->worst) = {order:?}");
for (id, dist) in &ranked {
println!(" id={id} _distance={dist:.4}");
}
if assert_maxsim {
verify_maxsim_order(metric, &order)?;
println!(" OK: {metric:?} reproduces MaxSim ranking.\n");
} else {
println!(" (reference only, no assertion)\n");
}
}
println!(
"PROBE PASSED: native multi-vector storage + MaxSim retrieval works \
on lance 7.0.0 / lancedb 0.30.0 via the production Table API, with \
no vector index required."
);
Ok(())
}
/// Converts the documents into a two-column Arrow `RecordBatch`.
///
/// Column `id` holds each document's id; column `vector` holds, per document,
/// a list of fixed-size (`DIM`) float lists, one per token.
///
/// # Errors
///
/// Returns an error (with context `"assemble record batch"`) when
/// `RecordBatch::try_from_iter` rejects the columns.
fn build_batch(docs: &[Doc]) -> Result<RecordBatch> {
    let id_values: Vec<u64> = docs.iter().map(|doc| doc.id).collect();
    let id_column: Arc<dyn Array> = Arc::new(UInt64Array::from(id_values));

    let token_builder = FixedSizeListBuilder::new(Float32Builder::new(), DIM);
    let mut list_builder = ListBuilder::new(token_builder);
    for doc in docs {
        for token in &doc.tokens {
            // Push the floats first, then close the fixed-size entry.
            let token_slot = list_builder.values();
            token_slot.values().append_slice(token);
            token_slot.append(true);
        }
        // Close the outer list entry for this document.
        list_builder.append(true);
    }
    let vector_column: Arc<dyn Array> = Arc::new(list_builder.finish());

    let columns = vec![("id", id_column), ("vector", vector_column)];
    RecordBatch::try_from_iter(columns).context("assemble record batch")
}
/// Runs a multi-vector search over the `vector` column and returns
/// `(id, _distance)` pairs sorted by ascending distance.
///
/// The first query token seeds the search and every remaining token is added
/// as a further query vector. The result limit comes from `docs_count()`.
///
/// # Errors
///
/// Fails when `query` is empty, when building or executing the search fails,
/// or when a returned batch lacks a `UInt64` `id` column or a `Float32`
/// `_distance` column.
async fn run_maxsim_query(
    table: &lancedb::Table,
    query: &[[f32; 3]],
    metric: DistanceType,
) -> Result<Vec<(u64, f32)>> {
    let mut tokens = query.iter();
    let head = tokens.next().ok_or_else(|| anyhow!("empty query"))?;

    let mut search = table.vector_search(head.to_vec())?;
    for token in tokens {
        search = search.add_query_vector(token.to_vec())?;
    }

    let stream = search
        .column("vector")
        .distance_type(metric)
        .limit(docs_count())
        .execute()
        .await?;
    let batches = stream.try_collect::<Vec<RecordBatch>>().await?;

    let mut scored: Vec<(u64, f32)> = Vec::new();
    for batch in &batches {
        let id_col = batch
            .column_by_name("id")
            .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
            .ok_or_else(|| anyhow!("missing/typed-wrong id column"))?;
        let dist_col = batch
            .column_by_name("_distance")
            .and_then(|col| col.as_any().downcast_ref::<Float32Array>())
            .ok_or_else(|| anyhow!("missing/typed-wrong _distance column"))?;

        scored.extend(
            (0..batch.num_rows()).map(|row| (id_col.value(row), dist_col.value(row))),
        );
    }

    // Stable sort: rows with equal distance keep the order they arrived in.
    scored.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
    Ok(scored)
}
/// Number of documents in the fixture; used as the result limit of the search.
fn docs_count() -> usize {
    const FIXTURE_DOCS: usize = 4;
    FIXTURE_DOCS
}
/// Checks that `order` matches the expected fixture ranking.
///
/// The ranking must contain exactly four ids: ids 1 and 2 in the first two
/// positions (either order), then id 3, then id 4.
///
/// # Errors
///
/// Returns an error naming `metric` and the first expectation that failed.
fn verify_maxsim_order(metric: DistanceType, order: &[u64]) -> Result<()> {
    use std::collections::BTreeSet;

    let (first, second, third, fourth) = match order {
        [a, b, c, d] => (*a, *b, *c, *d),
        _ => {
            return Err(anyhow!(
                "{metric:?}: expected 4 results, got {}",
                order.len()
            ));
        }
    };

    // Ids 1 and 2 tie, so only their membership in the top two is checked.
    let top_two: BTreeSet<u64> = BTreeSet::from([first, second]);
    let expected_top: BTreeSet<u64> = BTreeSet::from([1, 2]);
    if top_two != expected_top {
        return Err(anyhow!(
            "{metric:?}: top-2 should be {{1,2}} (MaxSim=2.0), got {top_two:?}"
        ));
    }

    if third != 3 {
        return Err(anyhow!(
            "{metric:?}: rank-3 should be id 3 (MaxSim=1.0), got {}",
            third
        ));
    }

    if fourth != 4 {
        return Err(anyhow!(
            "{metric:?}: rank-4 should be id 4 (MaxSim=0.0), got {}",
            fourth
        ));
    }

    Ok(())
}