1use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeResult, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use bytes::Bytes;
14
15use crate::pruner::VectorPruner;
16
/// Tuning knobs for [`search`] and for queries on a [`SearchSession`].
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results kept after hits from every file are merged and sorted.
    pub top_k: usize,
    /// Passed as the search-time `ef` argument to each file's index
    /// (presumably the HNSW beam width — TODO confirm for non-HNSW indexes).
    pub ef_search: usize,
    /// Files whose centroid lower bound (`distance(query, centroid) - radius`) exceeds this
    /// value are skipped. `f32::INFINITY` disables pruning.
    pub pruning_threshold: f32,
    /// When `Some(f)`, roughly `top_k * f` candidates are fetched per file and re-scored with
    /// exact distances on the raw vectors where those are available. `None` keeps the index
    /// distances as-is.
    pub rerank_factor: Option<usize>,
}
31
32impl Default for SearchConfig {
33 fn default() -> Self {
34 Self {
35 top_k: 10,
36 ef_search: 50,
37 pruning_threshold: f32::INFINITY,
38 rerank_factor: None,
39 }
40 }
41}
42
43impl SearchConfig {
44 pub fn with_pruning(mut self, threshold: f32) -> Self {
45 self.pruning_threshold = threshold;
46 self
47 }
48
49 pub fn with_reranking(mut self, factor: usize) -> Self {
50 self.rerank_factor = Some(factor);
51 self
52 }
53}
54
55#[derive(Debug)]
56pub struct SearchResult {
57 pub row_id: RowId,
58 pub distance: f32,
59 pub file_path: String,
60}
61
62pub async fn search(
70 table: &TableIdent,
71 query: &[f32],
72 config: SearchConfig,
73 vector_column: &str,
74 dim: u32,
75 catalog: Arc<dyn CatalogProvider>,
76 store: Arc<dyn Store>,
77) -> AilakeResult<Vec<SearchResult>> {
78 let all_files = catalog.list_files(table, None).await?;
80
81 let table_meta = catalog.load_table(table).await?;
83 let metric = parse_metric(
84 table_meta
85 .properties
86 .get("ailake.vector-metric")
87 .map(String::as_str)
88 .unwrap_or("cosine"),
89 );
90
91 let total_files = all_files.len();
93 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
94 debug!(
95 "ailake: geometric pruning — {}/{} files survive (threshold={})",
96 surviving_files.len(),
97 total_files,
98 config.pruning_threshold
99 );
100
101 let candidate_k = match config.rerank_factor {
102 Some(factor) => config.top_k * factor,
103 None => config.top_k,
104 };
105
106 let mut all_results: Vec<SearchResult> = Vec::new();
107
108 for file_entry in &surviving_files {
109 let file_bytes: Bytes = store.get(&file_entry.path).await?;
110 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
111
112 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
113 debug!(
115 "ailake: flat scan fallback for {} (index_status={:?})",
116 file_entry.path, file_entry.index_status
117 );
118 let (_, raw_vectors) = reader.read_parquet()?;
119 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
120 all_results.push(SearchResult {
121 row_id,
122 distance,
123 file_path: file_entry.path.clone(),
124 });
125 }
126 continue;
127 }
128
129 let index = reader.load_any_index_for_column(vector_column)?;
130 let local_results = index.search(query, candidate_k, config.ef_search);
131
132 if config.rerank_factor.is_some() {
133 let (_, raw_vectors) = reader.read_parquet()?;
135 for (row_id, _approx_dist) in local_results {
136 let idx = row_id.as_u64() as usize;
137 let exact_dist = match raw_vectors.get(idx) {
138 Some(v) => exact_distance(metric, query, v),
139 None => {
140 error!(
141 "ailake: invariant violated — row_id {} out of bounds \
142 (raw_vectors.len={}, file={}); \
143 Parquet row count and HNSW node count are out of sync; \
144 file may be corrupt — run compaction to rebuild",
145 idx,
146 raw_vectors.len(),
147 file_entry.path
148 );
149 f32::INFINITY
150 }
151 };
152 all_results.push(SearchResult {
153 row_id,
154 distance: exact_dist,
155 file_path: file_entry.path.clone(),
156 });
157 }
158 } else {
159 for (row_id, distance) in local_results {
160 all_results.push(SearchResult {
161 row_id,
162 distance,
163 file_path: file_entry.path.clone(),
164 });
165 }
166 }
167 }
168
169 all_results.sort_by(|a, b| {
171 a.distance
172 .partial_cmp(&b.distance)
173 .unwrap_or(std::cmp::Ordering::Equal)
174 });
175 all_results.truncate(config.top_k);
176 Ok(all_results)
177}
178
179fn flat_search(
181 raw: &[Vec<f32>],
182 query: &[f32],
183 top_k: usize,
184 metric: VectorMetric,
185) -> Vec<(RowId, f32)> {
186 let mut results: Vec<(RowId, f32)> = raw
187 .iter()
188 .enumerate()
189 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
190 .collect();
191 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
192 results.truncate(top_k);
193 results
194}
195
196fn parse_metric(s: &str) -> VectorMetric {
197 match s {
198 "euclidean" => VectorMetric::Euclidean,
199 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
200 _ => VectorMetric::Cosine,
201 }
202}
203
204pub struct SearchSession {
209 shards: Vec<LoadedShard>,
210 metric: VectorMetric,
211}
212
213struct LoadedShard {
214 entry: DataFileEntry,
215 index: Option<AnyIndex>,
217 raw_vectors: Option<Vec<Vec<f32>>>,
220}
221
222impl SearchSession {
223 pub async fn load(
229 table: &TableIdent,
230 vector_column: &str,
231 dim: u32,
232 catalog: Arc<dyn CatalogProvider>,
233 store: Arc<dyn Store>,
234 load_raw: bool,
235 ) -> AilakeResult<Self> {
236 let all_files = catalog.list_files(table, None).await?;
237 let table_meta = catalog.load_table(table).await?;
238 let metric = parse_metric(
239 table_meta
240 .properties
241 .get("ailake.vector-metric")
242 .map(String::as_str)
243 .unwrap_or("cosine"),
244 );
245
246 let mut shards = Vec::with_capacity(all_files.len());
247 for entry in all_files {
248 let file_bytes: Bytes = store.get(&entry.path).await?;
249 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
250
251 if entry.index_status == IndexStatus::Indexing {
252 let (_, raw_vecs) = reader.read_parquet()?;
254 shards.push(LoadedShard {
255 entry,
256 index: None,
257 raw_vectors: Some(raw_vecs),
258 });
259 } else if reader.is_ailake_file() {
260 let mut index = reader.load_any_index_for_column(vector_column)?;
261 let raw_vectors = if load_raw {
262 index.quantize_to_f16();
263 let (_, vecs) = reader.read_parquet()?;
264 Some(vecs)
265 } else {
266 None
267 };
268 shards.push(LoadedShard {
269 entry,
270 index: Some(index),
271 raw_vectors,
272 });
273 }
274 }
275
276 Ok(Self { shards, metric })
277 }
278
279 pub fn shard_count(&self) -> usize {
281 self.shards.len()
282 }
283
284 pub fn search_batch(
293 &self,
294 queries: &[Vec<f32>],
295 config: &SearchConfig,
296 ) -> Vec<Vec<SearchResult>> {
297 if queries.is_empty() {
298 return vec![];
299 }
300
301 let n_queries = queries.len();
302 let candidate_k = match config.rerank_factor {
303 Some(factor) => config.top_k * factor,
304 None => config.top_k,
305 };
306 let use_nvidia = ailake_index::hardware::detect_cuda();
307 let use_amd = ailake_index::hardware::detect_rocm();
308
309 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
311
312 for shard in &self.shards {
313 if let Some(raw) = &shard.raw_vectors {
314 if !raw.is_empty() {
316 let dim = raw[0].len();
317 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
318 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
319 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
320
321 let gpu_batch = if use_nvidia {
322 ailake_index::gpu::try_nvidia_search_batch(
323 &q_refs,
324 &row_ids,
325 &flat,
326 dim,
327 self.metric,
328 candidate_k,
329 )
330 } else if use_amd {
331 ailake_index::gpu::try_rocm_search_batch(
332 &q_refs,
333 &row_ids,
334 &flat,
335 dim,
336 self.metric,
337 candidate_k,
338 )
339 } else {
340 None
341 };
342
343 if let Some(batch) = gpu_batch {
344 for (qi, results) in batch.into_iter().enumerate() {
345 for (row_id, distance) in results {
346 all_results[qi].push(SearchResult {
347 row_id,
348 distance,
349 file_path: shard.entry.path.clone(),
350 });
351 }
352 }
353 continue;
354 }
355 }
356
357 for (qi, query) in queries.iter().enumerate() {
359 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
360 all_results[qi].push(SearchResult {
361 row_id,
362 distance,
363 file_path: shard.entry.path.clone(),
364 });
365 }
366 }
367 } else if let Some(index) = &shard.index {
368 let shard_results: Vec<Vec<SearchResult>> = queries
370 .par_iter()
371 .map(|query| {
372 index
373 .search(query, candidate_k, config.ef_search)
374 .into_iter()
375 .map(|(row_id, distance)| SearchResult {
376 row_id,
377 distance,
378 file_path: shard.entry.path.clone(),
379 })
380 .collect()
381 })
382 .collect();
383
384 for (qi, results) in shard_results.into_iter().enumerate() {
385 all_results[qi].extend(results);
386 }
387 }
388 }
389
390 for results in &mut all_results {
392 results.sort_by(|a, b| {
393 a.distance
394 .partial_cmp(&b.distance)
395 .unwrap_or(std::cmp::Ordering::Equal)
396 });
397 results.truncate(config.top_k);
398 }
399
400 all_results
401 }
402
403 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
405 let candidate_k = match config.rerank_factor {
406 Some(factor) => config.top_k * factor,
407 None => config.top_k,
408 };
409
410 let mut all_results: Vec<SearchResult> = self
411 .shards
412 .par_iter()
413 .flat_map(|shard| {
414 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
416 let dist = match self.metric {
417 VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
418 ailake_vec::cosine_distance(query, ¢roid.values)
419 }
420 VectorMetric::Euclidean => {
421 ailake_vec::euclidean_distance(query, ¢roid.values)
422 }
423 VectorMetric::DotProduct => {
424 -ailake_vec::dot_product(query, ¢roid.values)
425 }
426 };
427 if dist - centroid.radius > config.pruning_threshold {
428 return vec![];
429 }
430 }
431
432 if let Some(index) = &shard.index {
433 let local_results = index.search(query, candidate_k, config.ef_search);
435 if config.rerank_factor.is_some() {
436 if let Some(raw) = &shard.raw_vectors {
437 local_results
438 .into_iter()
439 .map(|(row_id, _approx_dist)| {
440 let idx = row_id.as_u64() as usize;
441 let exact_dist = raw
442 .get(idx)
443 .map(|v| exact_distance(self.metric, query, v))
444 .unwrap_or(f32::INFINITY);
445 SearchResult {
446 row_id,
447 distance: exact_dist,
448 file_path: shard.entry.path.clone(),
449 }
450 })
451 .collect()
452 } else {
453 local_results
454 .into_iter()
455 .map(|(row_id, distance)| SearchResult {
456 row_id,
457 distance,
458 file_path: shard.entry.path.clone(),
459 })
460 .collect()
461 }
462 } else {
463 local_results
464 .into_iter()
465 .map(|(row_id, distance)| SearchResult {
466 row_id,
467 distance,
468 file_path: shard.entry.path.clone(),
469 })
470 .collect()
471 }
472 } else if let Some(raw) = &shard.raw_vectors {
473 flat_search(raw, query, candidate_k, self.metric)
475 .into_iter()
476 .map(|(row_id, distance)| SearchResult {
477 row_id,
478 distance,
479 file_path: shard.entry.path.clone(),
480 })
481 .collect()
482 } else {
483 vec![]
484 }
485 })
486 .collect();
487
488 all_results.sort_by(|a, b| {
489 a.distance
490 .partial_cmp(&b.distance)
491 .unwrap_or(std::cmp::Ordering::Equal)
492 });
493 all_results.truncate(config.top_k);
494 all_results
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Vector of length `dim` that is 1.0 at position `hot` and 0.0 elsewhere.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        (0..dim).map(|j| if j == hot { 1.0 } else { 0.0 }).collect()
    }

    /// Cosine / F16 storage policy for the `embedding` column, with optional knobs unset.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Search config with the settings shared by every test; only `top_k` and reranking vary.
    fn config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Writes `rows` rows into `default.table`; row `i` is one-hot at position `i % dim`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids = Int32Array::from((0..rows as i32).collect::<Vec<i32>>());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(ids)]).unwrap();
        let embeddings: Vec<Vec<f32>> = (0..rows).map(|i| one_hot(dim, i % dim)).collect();

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Runs the one-shot `search` on `default.table` through a freshly opened store/catalog.
    async fn run_search(
        dir: &TempDir,
        query: &[f32],
        config: SearchConfig,
        dim: usize,
    ) -> Vec<SearchResult> {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");
        search(&table, query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = run_search(&dir, &one_hot(dim, 0), config(3, Some(2)), dim).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // The query equals row 0's embedding, so its exact (reranked) distance is ~0.
        let results = run_search(&dir, &one_hot(dim, 0), config(1, Some(4)), dim).await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let query = one_hot(dim, 0);
        let plain = run_search(&dir, &query, config(2, None), dim).await;
        let reranked = run_search(&dir, &query, config(2, Some(2)), dim).await;

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}