use std::sync::Arc;
use rayon::prelude::*;
use tracing::{debug, error};
use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
use ailake_file::AilakeFileReader;
use ailake_index::AnyIndex;
use ailake_store::Store;
use ailake_vec::exact_distance;
use arrow_array::RecordBatch;
use bytes::Bytes;
use crate::pruner::VectorPruner;
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and refine it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results returned to the caller.
    pub top_k: usize,
    /// Search-time breadth passed to the vector index's `search` call.
    pub ef_search: usize,
    /// Distance threshold for geometric file/shard pruning. The default,
    /// `f32::INFINITY`, is intended to keep every file.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, about `top_k * factor` candidates are taken from each
    /// file/shard and, where raw vectors are available, re-scored with exact distances
    /// before the final cut to `top_k`. `None` disables re-ranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// `top_k = 10`, `ef_search = 50`, no pruning, no re-ranking.
    fn default() -> Self {
        Self {
            top_k: 10,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}

impl SearchConfig {
    /// Sets the number of results to return.
    #[must_use]
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = top_k;
        self
    }

    /// Sets the index search breadth (`ef_search`).
    #[must_use]
    pub fn with_ef_search(mut self, ef_search: usize) -> Self {
        self.ef_search = ef_search;
        self
    }

    /// Sets the geometric pruning threshold.
    #[must_use]
    pub fn with_pruning(mut self, threshold: f32) -> Self {
        self.pruning_threshold = threshold;
        self
    }

    /// Enables exact re-ranking of `top_k * factor` candidates per file/shard.
    #[must_use]
    pub fn with_reranking(mut self, factor: usize) -> Self {
        self.rerank_factor = Some(factor);
        self
    }
}
/// One hit produced by a search.
#[derive(Debug)]
pub struct SearchResult {
/// Row position inside the data file at `file_path`; `fetch_rows` uses it as the
/// row index into that file.
pub row_id: RowId,
/// Distance to the query under the table's metric. Results are sorted ascending,
/// so smaller means closer. `f32::INFINITY` marks a re-ranked candidate whose raw
/// vector could not be found.
pub distance: f32,
/// Store path of the data file the row came from.
pub file_path: String,
}
pub async fn search(
table: &TableIdent,
query: &[f32],
config: SearchConfig,
vector_column: &str,
dim: u32,
catalog: Arc<dyn CatalogProvider>,
store: Arc<dyn Store>,
) -> AilakeResult<Vec<SearchResult>> {
let all_files = catalog.list_files(table, None).await?;
let table_meta = catalog.load_table(table).await?;
let metric = parse_metric(
table_meta
.properties
.get("ailake.vector-metric")
.map(String::as_str)
.unwrap_or("cosine"),
);
let total_files = all_files.len();
let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
debug!(
"ailake: geometric pruning — {}/{} files survive (threshold={})",
surviving_files.len(),
total_files,
config.pruning_threshold
);
let candidate_k = match config.rerank_factor {
Some(factor) => config.top_k * factor,
None => config.top_k,
};
let mut all_results: Vec<SearchResult> = Vec::new();
for file_entry in &surviving_files {
let file_bytes: Bytes = store.get(&file_entry.path).await?;
let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
debug!(
"ailake: flat scan fallback for {} (index_status={:?})",
file_entry.path, file_entry.index_status
);
let (_, raw_vectors) = reader.read_parquet()?;
for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
all_results.push(SearchResult {
row_id,
distance,
file_path: file_entry.path.clone(),
});
}
continue;
}
let index = reader.load_any_index_for_column(vector_column)?;
let local_results = index.search(query, candidate_k, config.ef_search);
if config.rerank_factor.is_some() {
let (_, raw_vectors) = reader.read_parquet()?;
for (row_id, _approx_dist) in local_results {
let idx = row_id.as_u64() as usize;
let exact_dist = match raw_vectors.get(idx) {
Some(v) => exact_distance(metric, query, v),
None => {
error!(
"ailake: invariant violated — row_id {} out of bounds \
(raw_vectors.len={}, file={}); \
Parquet row count and HNSW node count are out of sync; \
file may be corrupt — run compaction to rebuild",
idx,
raw_vectors.len(),
file_entry.path
);
f32::INFINITY
}
};
all_results.push(SearchResult {
row_id,
distance: exact_dist,
file_path: file_entry.path.clone(),
});
}
} else {
for (row_id, distance) in local_results {
all_results.push(SearchResult {
row_id,
distance,
file_path: file_entry.path.clone(),
});
}
}
}
all_results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(config.top_k);
Ok(all_results)
}
/// Exact (brute-force) k-nearest-neighbour scan over `raw`.
///
/// Returns up to `top_k` `(row_id, distance)` pairs ordered by ascending distance,
/// where `row_id` is the vector's position in `raw`. Ties are broken by lower row
/// index and NaN distances order after every real distance, so the output is
/// deterministic.
///
/// Only the `top_k` nearest are sorted after a partial selection, instead of sorting
/// all `n` distances.
fn flat_search(
    raw: &[Vec<f32>],
    query: &[f32],
    top_k: usize,
    metric: VectorMetric,
) -> Vec<(RowId, f32)> {
    if top_k == 0 {
        return Vec::new();
    }
    let mut scored: Vec<(usize, f32)> = raw
        .iter()
        .map(|v| exact_distance(metric, query, v))
        .enumerate()
        .collect();

    // Total order: distance ascending (NaN last), then row index. `partial_cmp` alone is
    // not a total order once NaN appears, which the sort functions may panic on.
    let nearest_first = |a: &(usize, f32), b: &(usize, f32)| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            .then_with(|| a.0.cmp(&b.0))
    };
    if top_k < scored.len() {
        // Move the `top_k` nearest to the front, then drop the rest.
        scored.select_nth_unstable_by(top_k, nearest_first);
        scored.truncate(top_k);
    }
    scored.sort_unstable_by(nearest_first);

    scored
        .into_iter()
        .map(|(row, distance)| (RowId::new(row as u64), distance))
        .collect()
}
fn parse_metric(s: &str) -> VectorMetric {
match s {
"euclidean" => VectorMetric::Euclidean,
"dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
_ => VectorMetric::Cosine,
}
}
/// In-memory snapshot of a table's data files, built once by `SearchSession::load`
/// and then queried repeatedly with `search_query` / `search_batch`.
pub struct SearchSession {
/// Shards loaded from the table's data files.
shards: Vec<LoadedShard>,
/// Distance metric parsed from the table property `ailake.vector-metric`.
metric: VectorMetric,
}
/// One data file held in memory by a `SearchSession`.
struct LoadedShard {
/// Catalog entry for the file (its path, index status, and the metadata that
/// `decode_centroid` reads for pruning).
entry: DataFileEntry,
/// Deserialized vector index; `None` when the shard is served by a flat scan.
index: Option<AnyIndex>,
/// Raw vectors decoded from Parquet, in row order. Present for shards without an
/// index, and for indexed shards when `load_raw` was requested (used to re-rank
/// index candidates with exact distances).
raw_vectors: Option<Vec<Vec<f32>>>,
}
impl SearchSession {
pub async fn load(
table: &TableIdent,
vector_column: &str,
dim: u32,
catalog: Arc<dyn CatalogProvider>,
store: Arc<dyn Store>,
load_raw: bool,
) -> AilakeResult<Self> {
let all_files = catalog.list_files(table, None).await?;
let table_meta = catalog.load_table(table).await?;
let metric = parse_metric(
table_meta
.properties
.get("ailake.vector-metric")
.map(String::as_str)
.unwrap_or("cosine"),
);
let mut shards = Vec::with_capacity(all_files.len());
for entry in all_files {
let file_bytes: Bytes = store.get(&entry.path).await?;
let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
if entry.index_status == IndexStatus::Indexing {
let (_, raw_vecs) = reader.read_parquet()?;
shards.push(LoadedShard {
entry,
index: None,
raw_vectors: Some(raw_vecs),
});
} else if reader.is_ailake_file() {
let mut index = reader.load_any_index_for_column(vector_column)?;
let raw_vectors = if load_raw {
index.quantize_to_f16();
let (_, vecs) = reader.read_parquet()?;
Some(vecs)
} else {
None
};
shards.push(LoadedShard {
entry,
index: Some(index),
raw_vectors,
});
}
}
Ok(Self { shards, metric })
}
pub fn shard_count(&self) -> usize {
self.shards.len()
}
pub fn search_batch(
&self,
queries: &[Vec<f32>],
config: &SearchConfig,
) -> Vec<Vec<SearchResult>> {
if queries.is_empty() {
return vec![];
}
let n_queries = queries.len();
let candidate_k = match config.rerank_factor {
Some(factor) => config.top_k * factor,
None => config.top_k,
};
let use_nvidia = ailake_index::hardware::detect_cuda();
let use_amd = ailake_index::hardware::detect_rocm();
let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
for shard in &self.shards {
if let Some(raw) = &shard.raw_vectors {
if !raw.is_empty() {
let dim = raw[0].len();
let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
let gpu_batch = if use_nvidia {
ailake_index::gpu::try_nvidia_search_batch(
&q_refs,
&row_ids,
&flat,
dim,
self.metric,
candidate_k,
)
} else if use_amd {
ailake_index::gpu::try_rocm_search_batch(
&q_refs,
&row_ids,
&flat,
dim,
self.metric,
candidate_k,
)
} else {
None
};
if let Some(batch) = gpu_batch {
for (qi, results) in batch.into_iter().enumerate() {
for (row_id, distance) in results {
all_results[qi].push(SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
});
}
}
continue;
}
}
for (qi, query) in queries.iter().enumerate() {
for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
all_results[qi].push(SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
});
}
}
} else if let Some(index) = &shard.index {
let shard_results: Vec<Vec<SearchResult>> = queries
.par_iter()
.map(|query| {
index
.search(query, candidate_k, config.ef_search)
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
})
.collect();
for (qi, results) in shard_results.into_iter().enumerate() {
all_results[qi].extend(results);
}
}
}
for results in &mut all_results {
results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(config.top_k);
}
all_results
}
pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
let candidate_k = match config.rerank_factor {
Some(factor) => config.top_k * factor,
None => config.top_k,
};
let mut all_results: Vec<SearchResult> = self
.shards
.par_iter()
.flat_map(|shard| {
if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
let dist = match self.metric {
VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
ailake_vec::cosine_distance(query, ¢roid.values)
}
VectorMetric::Euclidean => {
ailake_vec::euclidean_distance(query, ¢roid.values)
}
VectorMetric::DotProduct => {
-ailake_vec::dot_product(query, ¢roid.values)
}
};
if dist - centroid.radius > config.pruning_threshold {
return vec![];
}
}
if let Some(index) = &shard.index {
let local_results = index.search(query, candidate_k, config.ef_search);
if config.rerank_factor.is_some() {
if let Some(raw) = &shard.raw_vectors {
local_results
.into_iter()
.map(|(row_id, _approx_dist)| {
let idx = row_id.as_u64() as usize;
let exact_dist = raw
.get(idx)
.map(|v| exact_distance(self.metric, query, v))
.unwrap_or(f32::INFINITY);
SearchResult {
row_id,
distance: exact_dist,
file_path: shard.entry.path.clone(),
}
})
.collect()
} else {
local_results
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
}
} else {
local_results
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
}
} else if let Some(raw) = &shard.raw_vectors {
flat_search(raw, query, candidate_k, self.metric)
.into_iter()
.map(|(row_id, distance)| SearchResult {
row_id,
distance,
file_path: shard.entry.path.clone(),
})
.collect()
} else {
vec![]
}
})
.collect();
all_results.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(config.top_k);
all_results
}
}
pub async fn fetch_rows(
results: &[SearchResult],
store: Arc<dyn Store>,
vector_column: &str,
dim: u32,
) -> AilakeResult<RecordBatch> {
use std::collections::HashMap;
use arrow_array::{ArrayRef, Float32Array, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::{concat::concat_batches, take::take};
if results.is_empty() {
return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
}
let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
for (i, r) in results.iter().enumerate() {
by_file
.entry(r.file_path.as_str())
.or_default()
.push((r.row_id.as_u64(), r.distance, i));
}
use arrow_array::FixedSizeListArray;
let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
for (file_path, rows) in &by_file {
let bytes = store.get(file_path).await?;
let reader = AilakeFileReader::new(bytes, vector_column, dim);
let (batch, vectors) = reader.read_parquet()?;
for &(row_id, distance, pos) in rows {
let idx = row_id as usize;
if idx >= batch.num_rows() {
tracing::warn!(
"fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
idx,
batch.num_rows(),
file_path
);
continue;
}
let indices = UInt32Array::from(vec![idx as u32]);
let row_cols: Vec<ArrayRef> = batch
.columns()
.iter()
.map(|col| {
take(col.as_ref(), &indices, None)
.map_err(|e| AilakeError::Arrow(e.to_string()))
})
.collect::<AilakeResult<Vec<_>>>()?;
let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
.map_err(|e| AilakeError::Arrow(e.to_string()))?;
let vec = vectors
.get(idx)
.cloned()
.unwrap_or_else(|| vec![0.0f32; dim as usize]);
collected.push((pos, distance, row_batch, vec));
}
}
if collected.is_empty() {
return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
}
collected.sort_by_key(|(pos, _, _, _)| *pos);
let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
let base_schema = collected[0].2.schema();
let combined =
concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
let flat_vecs: Vec<f32> = collected
.iter()
.flat_map(|(_, _, _, v)| v.iter().copied())
.collect();
let item_field = Arc::new(Field::new("item", DataType::Float32, false));
let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
let vec_field = Arc::new(Field::new(
vector_column,
DataType::FixedSizeList(item_field, dim as i32),
false,
));
let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
fields.push(vec_field);
fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
let new_schema = Arc::new(Schema::new(fields));
let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
columns.push(Arc::new(vec_col));
columns.push(Arc::new(Float32Array::from(distances)));
RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy for the single `embedding` column every test uses.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
        }
    }

    /// The table every test writes to and reads from.
    fn demo_table() -> TableIdent {
        TableIdent::new("default", "table")
    }

    /// Vector of length `dim` that is 1.0 at `hot` and 0.0 elsewhere.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// Search config with the shared defaults and the given `top_k` / re-rank factor.
    fn search_config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Writes `rows` rows whose embeddings cycle through the `dim` one-hot vectors.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();
        let embeddings: Vec<Vec<f32>> = (0..rows).map(|row| one_hot(dim, row % dim)).collect();
        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), demo_table())
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Opens fresh store/catalog handles over `dir` and runs one `search`.
    async fn run_search(
        dir: &TempDir,
        query: &[f32],
        config: SearchConfig,
        dim: usize,
    ) -> Vec<SearchResult> {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        search(
            &demo_table(),
            query,
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let hits = run_search(&dir, &one_hot(dim, 0), search_config(3, Some(2)), dim).await;
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let hits = run_search(&dir, &one_hot(dim, 0), search_config(1, Some(4)), dim).await;
        assert_eq!(hits.len(), 1);
        let best = &hits[0];
        assert!(best.distance < 1e-3, "distance was {}", best.distance);
        assert_eq!(best.row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        let query = one_hot(dim, 0);
        let plain = run_search(&dir, &query, search_config(2, None), dim).await;
        let reranked = run_search(&dir, &query, search_config(2, Some(2)), dim).await;
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}