use std::sync::Arc;
use rayon::prelude::*;
use tracing::{debug, error};
use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
use ailake_file::AilakeFileReader;
use ailake_index::AnyIndex;
use ailake_store::Store;
use ailake_vec::exact_distance;
use arrow_array::RecordBatch;
use bytes::Bytes;
use crate::pruner::VectorPruner;
/// Tunable parameters for one vector search request.
#[derive(Clone, Debug)]
pub struct SearchConfig {
    /// Maximum number of results handed back to the caller.
    pub top_k: usize,
    /// Search breadth forwarded unchanged to each per-file index search.
    pub ef_search: usize,
    /// Cut-off used when pruning files/shards by their centroid; the
    /// default of `f32::INFINITY` never prunes anything.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, `top_k * factor` candidates are requested per
    /// file so they can be re-scored with exact distances; `None` disables
    /// re-ranking.
    pub rerank_factor: Option<usize>,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
top_k: 10,
ef_search: 50,
pruning_threshold: f32::INFINITY,
rerank_factor: None,
}
}
}
impl SearchConfig {
    /// Consumes the configuration and returns it with `pruning_threshold`
    /// replaced by `threshold`; every other field is carried over.
    pub fn with_pruning(self, threshold: f32) -> Self {
        Self {
            pruning_threshold: threshold,
            ..self
        }
    }

    /// Consumes the configuration and returns it with re-ranking enabled
    /// at the given candidate multiplier (`rerank_factor = Some(factor)`).
    pub fn with_reranking(self, factor: usize) -> Self {
        Self {
            rerank_factor: Some(factor),
            ..self
        }
    }
}
/// One hit produced by a search.
#[derive(Debug)]
pub struct SearchResult {
    /// Identifier of the matching row; it is used as a zero-based row
    /// position inside `file_path` when rows or raw vectors are looked up.
    pub row_id: RowId,
    /// Distance between the query and the row (smaller sorts first). For
    /// multimodal searches this holds the negated fusion score instead.
    pub distance: f32,
    /// Store path of the data file that contains the row.
    pub file_path: String,
}
/// Searches every data file of `table` for the rows whose `vector_column`
/// is nearest to `query`.
///
/// Processing steps:
/// 1. List the table's files and load its properties from `catalog`.
/// 2. If the table stores a dimension for the column (`ailake.vector-dim`
///    for the primary column, `ailake.dim-<column>` otherwise) and it
///    parses as `u32`, reject queries of a different length. A missing or
///    unparsable property skips the check.
/// 3. Resolve the metric from `ailake.vector-metric` (primary column) or
///    `ailake.metric-<column>`, falling back to the table-wide key and
///    finally to `"cosine"`.
/// 4. Drop files via `VectorPruner::prune` using `config.pruning_threshold`.
/// 5. For each surviving file: brute-force scan when the file is still
///    `IndexStatus::Indexing` or is not an ailake file, otherwise query its
///    index. With `config.rerank_factor` set, index hits are re-scored with
///    exact distances computed from the file's raw vectors.
/// 6. Merge all hits, sort ascending by distance and keep `config.top_k`.
///
/// `dim` is forwarded to `AilakeFileReader::new` for every file.
///
/// # Errors
/// Returns [`AilakeError::ModelMismatch`] on a dimension mismatch, and
/// propagates any catalog, store or file-reader error.
pub async fn search(
    table: &TableIdent,
    query: &[f32],
    config: SearchConfig,
    vector_column: &str,
    dim: u32,
    catalog: Arc<dyn CatalogProvider>,
    store: Arc<dyn Store>,
) -> AilakeResult<Vec<SearchResult>> {
    let all_files = catalog.list_files(table, None).await?;
    let table_meta = catalog.load_table(table).await?;

    let primary_col = table_meta
        .properties
        .get("ailake.vector-column")
        .map(String::as_str)
        .unwrap_or("");
    let is_primary = vector_column == primary_col;

    // --- Dimension guard -------------------------------------------------
    let stored_dim_key = if is_primary {
        "ailake.vector-dim".to_string()
    } else {
        format!("ailake.dim-{vector_column}")
    };
    let stored_dim = table_meta
        .properties
        .get(&stored_dim_key)
        .and_then(|s| s.parse::<u32>().ok());
    if let Some(table_dim) = stored_dim {
        // A checked conversion: a plain `as u32` cast would wrap for absurdly
        // long queries and could make them look like a match.
        let query_dim = u32::try_from(query.len()).unwrap_or(u32::MAX);
        if query_dim != table_dim {
            let table_model = table_meta
                .properties
                .get(EmbeddingModelInfo::property_key())
                .cloned()
                .unwrap_or_else(|| format!("dim={}", table_dim));
            return Err(AilakeError::ModelMismatch {
                table_model,
                table_dim,
                batch_model: format!("query dim={}", query_dim),
                batch_dim: query_dim,
            });
        }
    }

    // --- Metric resolution -----------------------------------------------
    let metric_key = if is_primary {
        "ailake.vector-metric".to_string()
    } else {
        format!("ailake.metric-{vector_column}")
    };
    let metric = parse_metric(
        table_meta
            .properties
            .get(&metric_key)
            .or_else(|| table_meta.properties.get("ailake.vector-metric"))
            .map(String::as_str)
            .unwrap_or("cosine"),
    );

    // --- File-level pruning ----------------------------------------------
    let total_files = all_files.len();
    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
    debug!(
        "ailake: geometric pruning — {}/{} files survive (threshold={})",
        surviving_files.len(),
        total_files,
        config.pruning_threshold
    );

    // Over-fetch when re-ranking. The multiply saturates instead of
    // overflowing, and a factor of 0 is treated as 1 so that enabling
    // re-ranking can never shrink the candidate set below `top_k`.
    let candidate_k = match config.rerank_factor {
        Some(factor) => config.top_k.saturating_mul(factor.max(1)),
        None => config.top_k,
    };

    let mut all_results: Vec<SearchResult> = Vec::new();
    for file_entry in &surviving_files {
        let file_bytes: Bytes = store.get(&file_entry.path).await?;
        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);

        // No usable index yet (or a plain Parquet file): exact scan.
        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
            debug!(
                "ailake: flat scan fallback for {} (index_status={:?})",
                file_entry.path, file_entry.index_status
            );
            let (_, raw_vectors) = reader.read_parquet()?;
            all_results.extend(
                flat_search(&raw_vectors, query, candidate_k, metric)
                    .into_iter()
                    .map(|(row_id, distance)| SearchResult {
                        row_id,
                        distance,
                        file_path: file_entry.path.clone(),
                    }),
            );
            continue;
        }

        let index = reader.load_any_index_for_column(vector_column)?;
        let local_results = index.search(query, candidate_k, config.ef_search);

        if config.rerank_factor.is_some() {
            // Replace the index's approximate distances with exact ones.
            let (_, raw_vectors) = reader.read_parquet()?;
            for (row_id, _approx_dist) in local_results {
                let idx = row_id.as_u64() as usize;
                let exact_dist = match raw_vectors.get(idx) {
                    Some(v) => exact_distance(metric, query, v),
                    None => {
                        error!(
                            "ailake: invariant violated — row_id {} out of bounds \
                             (raw_vectors.len={}, file={}); \
                             Parquet row count and HNSW node count are out of sync; \
                             file may be corrupt — run compaction to rebuild",
                            idx,
                            raw_vectors.len(),
                            file_entry.path
                        );
                        // Keep the hit but push it to the very end of the ranking.
                        f32::INFINITY
                    }
                };
                all_results.push(SearchResult {
                    row_id,
                    distance: exact_dist,
                    file_path: file_entry.path.clone(),
                });
            }
        } else {
            all_results.extend(local_results.into_iter().map(|(row_id, distance)| {
                SearchResult {
                    row_id,
                    distance,
                    file_path: file_entry.path.clone(),
                }
            }));
        }
    }

    // `total_cmp` is a genuine total order (NaN sorts after every finite
    // value), which `sort_by` requires; the previous
    // `partial_cmp(..).unwrap_or(Equal)` was not one when a distance was NaN.
    all_results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
    all_results.truncate(config.top_k);
    Ok(all_results)
}
/// One per-column query taking part in a multimodal search.
#[derive(Clone, Debug)]
pub struct ModalQuery<'a> {
    /// Name of the vector column to search.
    pub column: &'a str,
    /// Query embedding for that column.
    pub query: &'a [f32],
    /// Weight applied to this column's contribution when rankings are fused.
    pub weight: f32,
    /// Vector dimension of the column; `0` asks the search to resolve it
    /// from table properties (or, failing that, from `query.len()`).
    pub dim: u32,
}
/// Strategy used to merge per-column rankings into one result list.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal rank fusion: each hit scores `weight / (60 + rank + 1)`
    /// per column, and the scores are summed across columns.
    Rrf,
}
pub async fn search_multimodal(
table: &TableIdent,
queries: &[ModalQuery<'_>],
config: SearchConfig,
catalog: Arc<dyn CatalogProvider>,
store: Arc<dyn Store>,
fusion: FusionMethod,
) -> AilakeResult<Vec<SearchResult>> {
use std::collections::HashMap;
if queries.is_empty() {
return Err(AilakeError::InvalidArgument(
"search_multimodal requires at least one ModalQuery".into(),
));
}
let table_meta = catalog.load_table(table).await?;
let primary_col = table_meta
.properties
.get("ailake.vector-column")
.cloned()
.unwrap_or_default();
let primary_dim: u32 = table_meta
.properties
.get("ailake.vector-dim")
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let per_col_k = (config.top_k * queries.len().max(2)).min(1000);
let mut per_col_results: Vec<(f32, Vec<SearchResult>)> = Vec::with_capacity(queries.len());
for mq in queries {
let resolved_dim = if mq.dim > 0 {
mq.dim
} else if mq.column == primary_col {
primary_dim
} else {
table_meta
.properties
.get(&format!("ailake.dim-{}", mq.column))
.and_then(|s| s.parse().ok())
.unwrap_or(mq.query.len() as u32)
};
let col_config = SearchConfig {
top_k: per_col_k,
ef_search: config.ef_search,
pruning_threshold: config.pruning_threshold,
rerank_factor: config.rerank_factor,
};
let results = search(
table,
mq.query,
col_config,
mq.column,
resolved_dim,
catalog.clone(),
store.clone(),
)
.await?;
per_col_results.push((mq.weight, results));
}
const K: f32 = 60.0;
let mut scores: HashMap<(String, u64), f32> = HashMap::new();
for (weight, results) in &per_col_results {
for (rank, r) in results.iter().enumerate() {
let key = (r.file_path.clone(), r.row_id.as_u64());
let rrf = weight / (K + rank as f32 + 1.0);
*scores.entry(key).or_insert(0.0) += rrf;
}
}
let all_files = catalog.list_files(table, None).await?;
let _ = all_files;
let mut seen: HashMap<(String, u64), f32> = HashMap::new();
for (_, results) in &per_col_results {
for r in results {
let key = (r.file_path.clone(), r.row_id.as_u64());
let rrf_score = *scores.get(&key).unwrap_or(&0.0);
seen.entry(key).or_insert(rrf_score);
}
}
let mut fused: Vec<SearchResult> = seen
.into_iter()
.map(|((file_path, row_id_u64), rrf_score)| SearchResult {
row_id: RowId::new(row_id_u64),
distance: -rrf_score,
file_path,
})
.collect();
fused.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
fused.truncate(config.top_k);
let _ = fusion;
Ok(fused)
}
/// Brute-force nearest-neighbour scan over `raw`.
///
/// Computes the exact `metric` distance from `query` to every vector and
/// returns the `top_k` closest as `(row id, distance)` pairs, nearest
/// first. The row id is the vector's position in `raw`. Rows with equal
/// distance keep their original order, and NaN distances sort after every
/// finite distance.
fn flat_search(
    raw: &[Vec<f32>],
    query: &[f32],
    top_k: usize,
    metric: VectorMetric,
) -> Vec<(RowId, f32)> {
    if top_k == 0 {
        // Nothing can be returned, so skip the distance computations.
        return Vec::new();
    }
    let mut results: Vec<(RowId, f32)> = raw
        .iter()
        .enumerate()
        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
        .collect();
    // `total_cmp` is a total order even with NaN present, which `sort_by`
    // requires; `partial_cmp(..).unwrap_or(Equal)` is not.
    results.sort_by(|a, b| a.1.total_cmp(&b.1));
    results.truncate(top_k);
    results
}
/// Maps a metric name stored in table properties to a [`VectorMetric`].
///
/// `"euclidean"` selects Euclidean distance; `"dotproduct"`,
/// `"dot_product"` and `"dot"` select dot product. Every other string
/// (including unknown names) falls back to cosine. Matching is exact and
/// case-sensitive.
fn parse_metric(s: &str) -> VectorMetric {
    if s == "euclidean" {
        return VectorMetric::Euclidean;
    }
    if matches!(s, "dotproduct" | "dot_product" | "dot") {
        return VectorMetric::DotProduct;
    }
    VectorMetric::Cosine
}
/// A table's data files loaded into memory once so that many queries can
/// be answered without going back to the store.
pub struct SearchSession {
    /// One entry per data file that was loaded.
    shards: Vec<LoadedShard>,
    /// Distance metric used for exact distances and centroid pruning.
    metric: VectorMetric,
}
/// In-memory state of a single data file held by a `SearchSession`.
struct LoadedShard {
    /// Catalog entry of the file (path, index status, pruning metadata).
    entry: DataFileEntry,
    /// The file's vector index; `None` for files that are searched by
    /// brute force because their index is still being built.
    index: Option<AnyIndex>,
    /// Raw vectors of the file, present for brute-force shards and when
    /// the session was asked to keep raw vectors for exact re-scoring.
    raw_vectors: Option<Vec<Vec<f32>>>,
}
impl SearchSession {
pub async fn load(
table: &TableIdent,
vector_column: &str,
dim: u32,
catalog: Arc<dyn CatalogProvider>,
store: Arc<dyn Store>,
load_raw: bool,
) -> AilakeResult<Self> {
let all_files = catalog.list_files(table, None).await?;
let table_meta = catalog.load_table(table).await?;
let metric = parse_metric(
table_meta
.properties
.get("ailake.vector-metric")
.map(String::as_str)
.unwrap_or("cosine"),
);
let mut shards = Vec::with_capacity(all_files.len());
for entry in all_files {
let file_bytes: Bytes = store.get(&entry.path).await?;
let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
if entry.index_status == IndexStatus::Indexing {
let (_, raw_vecs) = reader.read_parquet()?;
shards.push(LoadedShard {
entry,
index: None,
raw_vectors: Some(raw_vecs),
});
} else if reader.is_ailake_file() {
let mut index = reader.load_any_index_for_column(vector_column)?;
let raw_vectors = if load_raw {
index.quantize_to_f16();
let (_, vecs) = reader.read_parquet()?;
Some(vecs)
} else {
None
};
shards.push(LoadedShard {
entry,
index: Some(index),
raw_vectors,
});
}
}
Ok(Self { shards, metric })
}
pub fn shard_count(&self) -> usize {
self.shards.len()
}
pub fn search_batch(
&self,
queries: &[Vec<f32>],
config: &SearchConfig,
) -> Vec<Vec<SearchResult>> {
if queries.is_empty() {
return vec![];
}
let n_queries = queries.len();
let candidate_k = match config.rerank_factor {
Some(factor) => config.top_k * factor,
None => config.top_k,
};
let use_nvidia = ailake_index::hardware::detect_cuda();
let use_amd = ailake_index::hardware::detect_rocm();
let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
for shard in &self.shards {
if let Some(raw) = &shard.raw_vectors {
if !raw.is_empty() {
let dim = raw[0].len();
let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
let gpu_batch = if use_nvidia {
ailake_index::gpu::try_nvidia_search_batch(
&q_refs,
&row_ids,
&flat,
dim,
self.metric,
candidate_k,
)
} else if use_amd {
ailake_index::gpu::try_rocm_search_batch(
&q_refs,
&row_ids,
&flat,
dim,
self.metric,
candidate_k,
)
} else {
None
};
if let Some(batch) = gpu_batch {
for (qi, results) in batch.into_iter().enumerate() {
for (row_id, distance) in results {
all_results[qi].push(SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
});
}
}
continue;
}
}
for (qi, query) in queries.iter().enumerate() {
for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
all_results[qi].push(SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
});
}
}
} else if let Some(index) = &shard.index {
let shard_results: Vec<Vec<SearchResult>> = queries
.par_iter()
.map(|query| {
index
.search(query, candidate_k, config.ef_search)
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
})
.collect();
for (qi, results) in shard_results.into_iter().enumerate() {
all_results[qi].extend(results);
}
}
}
for results in &mut all_results {
results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(config.top_k);
}
all_results
}
pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
let candidate_k = match config.rerank_factor {
Some(factor) => config.top_k * factor,
None => config.top_k,
};
let mut all_results: Vec<SearchResult> = self
.shards
.par_iter()
.flat_map(|shard| {
if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
let dist = match self.metric {
VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
ailake_vec::cosine_distance(query, ¢roid.values)
}
VectorMetric::Euclidean => {
ailake_vec::euclidean_distance(query, ¢roid.values)
}
VectorMetric::DotProduct => {
-ailake_vec::dot_product(query, ¢roid.values)
}
};
if dist - centroid.radius > config.pruning_threshold {
return vec![];
}
}
if let Some(index) = &shard.index {
let local_results = index.search(query, candidate_k, config.ef_search);
if config.rerank_factor.is_some() {
if let Some(raw) = &shard.raw_vectors {
local_results
.into_iter()
.map(|(row_id, _approx_dist)| {
let idx = row_id.as_u64() as usize;
let exact_dist = raw
.get(idx)
.map(|v| exact_distance(self.metric, query, v))
.unwrap_or(f32::INFINITY);
SearchResult {
row_id,
distance: exact_dist,
file_path: shard.entry.path.clone(),
}
})
.collect()
} else {
local_results
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
}
} else {
local_results
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
}
} else if let Some(raw) = &shard.raw_vectors {
flat_search(raw, query, candidate_k, self.metric)
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
} else {
vec![]
}
})
.collect();
all_results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(config.top_k);
all_results
}
}
/// Materialises the rows referenced by `results` as a single Arrow batch.
///
/// Results are grouped by file so each file is read from `store` only
/// once. The output keeps the order of `results` and contains the file's
/// own columns followed by two appended columns: `vector_column` (a
/// `FixedSizeList<Float32>` of length `dim`) and `_distance` (`Float32`).
///
/// * A result whose row id lies outside its file is skipped with a warning,
///   so the output may have fewer rows than `results`.
/// * A row for which the file yields no vector gets an all-zero vector.
/// * When `results` is empty, or every result was skipped, an empty batch
///   with an empty schema is returned.
///
/// # Errors
/// * [`AilakeError::InvalidArgument`] when `dim` exceeds `i32::MAX`, the
///   largest list size Arrow can represent.
/// * [`AilakeError::Arrow`] when a stored vector's length differs from
///   `dim`, or when Arrow rejects the assembled batch (for example because
///   the files do not share a schema).
/// * Any store or file-reader error.
pub async fn fetch_rows(
    results: &[SearchResult],
    store: Arc<dyn Store>,
    vector_column: &str,
    dim: u32,
) -> AilakeResult<RecordBatch> {
    use arrow_array::{ArrayRef, FixedSizeListArray, Float32Array};
    use arrow_schema::{DataType, Field, Schema};
    use arrow_select::concat::concat_batches;
    use std::collections::HashMap;

    if results.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
    }

    // Arrow stores the list size as i32; a wrapping `as i32` cast would make
    // the array constructor panic on a negative size.
    let list_size = i32::try_from(dim).map_err(|_| {
        AilakeError::InvalidArgument(
            "fetch_rows: dim exceeds i32::MAX, the largest Arrow FixedSizeList size".into(),
        )
    })?;
    let vec_len = dim as usize;

    // file path -> (row id, distance, position in `results`)
    let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
    for (pos, hit) in results.iter().enumerate() {
        by_file
            .entry(hit.file_path.as_str())
            .or_default()
            .push((hit.row_id.as_u64(), hit.distance, pos));
    }

    let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
    for (file_path, rows) in &by_file {
        let bytes = store.get(file_path).await?;
        let reader = AilakeFileReader::new(bytes, vector_column, dim);
        let (batch, vectors) = reader.read_parquet()?;
        for &(row_id, distance, pos) in rows {
            // An id that does not fit in usize is necessarily out of bounds.
            let idx = usize::try_from(row_id).unwrap_or(usize::MAX);
            if idx >= batch.num_rows() {
                tracing::warn!(
                    "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
                    row_id,
                    batch.num_rows(),
                    file_path
                );
                continue;
            }
            // Zero-copy one-row view; the copy happens once, in `concat_batches`.
            let row_batch = batch.slice(idx, 1);
            let vector = match vectors.get(idx) {
                Some(v) if v.len() == vec_len => v.clone(),
                Some(v) => {
                    // A wrong-length vector would shift every following row's
                    // values inside the flattened list column.
                    return Err(AilakeError::Arrow(format!(
                        "fetch_rows: vector for row_id {} in {} has length {}, expected {}",
                        row_id,
                        file_path,
                        v.len(),
                        vec_len
                    )));
                }
                None => vec![0.0f32; vec_len],
            };
            collected.push((pos, distance, row_batch, vector));
        }
    }
    if collected.is_empty() {
        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
    }

    // Files were visited in arbitrary (hash) order; restore the caller's order.
    collected.sort_by_key(|(pos, _, _, _)| *pos);

    let row_count = collected.len();
    let mut distances: Vec<f32> = Vec::with_capacity(row_count);
    let mut row_batches: Vec<RecordBatch> = Vec::with_capacity(row_count);
    let mut flat_vecs: Vec<f32> = Vec::with_capacity(row_count.saturating_mul(vec_len));
    for (_, distance, row_batch, vector) in collected {
        distances.push(distance);
        row_batches.push(row_batch);
        flat_vecs.extend(vector);
    }

    let base_schema = row_batches[0].schema();
    let combined = concat_batches(&base_schema, &row_batches)
        .map_err(|e| AilakeError::Arrow(e.to_string()))?;

    let item_field = Arc::new(Field::new("item", DataType::Float32, false));
    let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
    let vec_col = FixedSizeListArray::try_new(item_field.clone(), list_size, values_arr, None)
        .map_err(|e| AilakeError::Arrow(e.to_string()))?;
    let vec_field = Arc::new(Field::new(
        vector_column,
        DataType::FixedSizeList(item_field, list_size),
        false,
    ));

    let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
    fields.push(vec_field);
    fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
    let new_schema = Arc::new(Schema::new(fields));

    let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
    columns.push(Arc::new(vec_col));
    columns.push(Arc::new(Float32Array::from(distances)));
    RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::MultiVectorBatch;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Cosine / F16 storage policy for `column` that keeps raw vectors.
    fn policy_for(column: &str, dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: column.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
        }
    }

    /// Policy for the default `embedding` column.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        policy_for("embedding", dim)
    }

    /// Fresh store and catalog handles rooted at `dir`, plus the test table id.
    fn open_env(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>, TableIdent) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (store, catalog, TableIdent::new("default", "table"))
    }

    /// `rows` one-hot vectors of length `dim`; row `i` is hot at `i % dim`.
    fn one_hot_rows(rows: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..rows)
            .map(|i| {
                let mut v = vec![0.0f32; dim];
                v[i % dim] = 1.0;
                v
            })
            .collect()
    }

    /// Single-column batch whose `id` column counts `0..rows`.
    fn id_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap()
    }

    /// Search config with pruning disabled and `ef_search` of 50.
    fn config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Query vector of length `dim` pointing along the first axis.
    fn first_axis_query(dim: usize) -> Vec<f32> {
        let mut q = vec![0.0f32; dim];
        q[0] = 1.0;
        q
    }

    /// Writes and commits `rows` one-hot embeddings into the test table.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let (store, catalog, table) = open_env(dir);
        let batch = id_batch(rows);
        let embeddings = one_hot_rows(rows, dim);
        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_env(&dir);
        let query = first_axis_query(dim);

        let results = search(
            &table,
            &query,
            config(3, Some(2)),
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_env(&dir);
        let query = first_axis_query(dim);

        let results = search(
            &table,
            &query,
            config(1, Some(4)),
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 1);
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        // Independent handles for each run, both pointing at the same data.
        let (store_a, cat_a, table) = open_env(&dir);
        let (store_b, cat_b, _) = open_env(&dir);
        let query = first_axis_query(dim);

        let plain = search(
            &table,
            &query,
            config(2, None),
            "embedding",
            dim as u32,
            cat_a,
            store_a,
        )
        .await
        .unwrap();
        let reranked = search(
            &table,
            &query,
            config(2, Some(2)),
            "embedding",
            dim as u32,
            cat_b,
            store_b,
        )
        .await
        .unwrap();

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }

    #[tokio::test]
    async fn multimodal_rrf_returns_top_k() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        let (store, catalog, table) = open_env(&dir);

        let q1 = vec![1.0f32, 0.0, 0.0, 0.0];
        let q2 = vec![0.0f32, 1.0, 0.0, 0.0];
        // Two queries against the same column, fused with different weights.
        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q1,
                weight: 0.7,
                dim: dim as u32,
            },
            ModalQuery {
                column: "embedding",
                query: &q2,
                weight: 0.3,
                dim: dim as u32,
            },
        ];

        let results = search_multimodal(
            &table,
            &queries,
            config(2, None),
            catalog,
            store,
            FusionMethod::Rrf,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].distance <= 0.0);
        assert!(results[0].row_id.as_u64() < 4);
    }

    #[tokio::test]
    async fn multimodal_rrf_cross_modal_different_dims() {
        let dir = TempDir::new().unwrap();
        let (store, catalog, table) = open_env(&dir);

        let rows = 4usize;
        let batch = id_batch(rows);
        // A 4-dimensional "text" column and a 2-dimensional "image" column.
        let text_embs = one_hot_rows(rows, 4);
        let img_embs = one_hot_rows(rows, 2);

        let mut writer = crate::TableWriter::create_or_open(
            catalog.clone(),
            store.clone(),
            make_policy(4),
            table.clone(),
        )
        .await
        .unwrap();
        let batches = [
            MultiVectorBatch {
                policy: make_policy(4),
                embeddings: &text_embs,
            },
            MultiVectorBatch {
                policy: policy_for("img_embedding", 2),
                embeddings: &img_embs,
            },
        ];
        writer.write_batch_multi(&batch, &batches).await.unwrap();
        writer.commit().await.unwrap();

        let q_text = vec![1.0f32, 0.0, 0.0, 0.0];
        let q_img = vec![1.0f32, 0.0];
        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q_text,
                weight: 0.6,
                dim: 4,
            },
            ModalQuery {
                column: "img_embedding",
                query: &q_img,
                weight: 0.4,
                dim: 2,
            },
        ];

        let results = search_multimodal(
            &table,
            &queries,
            config(2, None),
            catalog,
            store,
            FusionMethod::Rrf,
        )
        .await
        .unwrap();

        assert!(!results.is_empty(), "should return results");
        assert!(results[0].distance <= 0.0, "distance is -rrf_score");
        assert_eq!(results[0].row_id.as_u64(), 0, "row 0 should rank first");
    }
}