1use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::RecordBatch;
14use bytes::Bytes;
15
16use crate::pruner::VectorPruner;
17
/// Tuning parameters for a vector search.
///
/// Construct with [`SearchConfig::default`] and adjust with the `with_*`
/// builders, or fill the fields directly.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results returned after merging every file / shard.
    pub top_k: usize,
    /// Search-time `ef` value forwarded to each index's `search` call
    /// (presumably the HNSW candidate-list width — confirm in `ailake_index`).
    pub ef_search: usize,
    /// Geometric pruning threshold. `search` passes it to `VectorPruner::prune`;
    /// `SearchSession::search_query` skips a shard when its centroid distance
    /// minus its radius exceeds this value. The default, `f32::INFINITY`, never
    /// triggers that check.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, `top_k * factor` candidates are requested from each
    /// index and, where raw vectors are available, re-scored with exact
    /// distances before the final `top_k` cut. `None` disables reranking.
    pub rerank_factor: Option<usize>,
}
32
33impl Default for SearchConfig {
34 fn default() -> Self {
35 Self {
36 top_k: 10,
37 ef_search: 50,
38 pruning_threshold: f32::INFINITY,
39 rerank_factor: None,
40 }
41 }
42}
43
44impl SearchConfig {
45 pub fn with_pruning(mut self, threshold: f32) -> Self {
46 self.pruning_threshold = threshold;
47 self
48 }
49
50 pub fn with_reranking(mut self, factor: usize) -> Self {
51 self.rerank_factor = Some(factor);
52 self
53 }
54}
55
/// A single hit produced by a vector search.
#[derive(Debug)]
pub struct SearchResult {
    /// Row position inside the data file; the search code uses
    /// `row_id.as_u64()` as an index into the file's raw vectors / Parquet rows.
    pub row_id: RowId,
    /// Distance from the query under the table metric. Results are sorted by
    /// this value in ascending order, so smaller means closer.
    pub distance: f32,
    /// Path of the data file (as listed by the catalog) that holds the row.
    pub file_path: String,
}
62
63pub async fn search(
71 table: &TableIdent,
72 query: &[f32],
73 config: SearchConfig,
74 vector_column: &str,
75 dim: u32,
76 catalog: Arc<dyn CatalogProvider>,
77 store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79 let all_files = catalog.list_files(table, None).await?;
81
82 let table_meta = catalog.load_table(table).await?;
84 let metric = parse_metric(
85 table_meta
86 .properties
87 .get("ailake.vector-metric")
88 .map(String::as_str)
89 .unwrap_or("cosine"),
90 );
91
92 let total_files = all_files.len();
94 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
95 debug!(
96 "ailake: geometric pruning — {}/{} files survive (threshold={})",
97 surviving_files.len(),
98 total_files,
99 config.pruning_threshold
100 );
101
102 let candidate_k = match config.rerank_factor {
103 Some(factor) => config.top_k * factor,
104 None => config.top_k,
105 };
106
107 let mut all_results: Vec<SearchResult> = Vec::new();
108
109 for file_entry in &surviving_files {
110 let file_bytes: Bytes = store.get(&file_entry.path).await?;
111 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
112
113 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
114 debug!(
116 "ailake: flat scan fallback for {} (index_status={:?})",
117 file_entry.path, file_entry.index_status
118 );
119 let (_, raw_vectors) = reader.read_parquet()?;
120 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
121 all_results.push(SearchResult {
122 row_id,
123 distance,
124 file_path: file_entry.path.clone(),
125 });
126 }
127 continue;
128 }
129
130 let index = reader.load_any_index_for_column(vector_column)?;
131 let local_results = index.search(query, candidate_k, config.ef_search);
132
133 if config.rerank_factor.is_some() {
134 let (_, raw_vectors) = reader.read_parquet()?;
136 for (row_id, _approx_dist) in local_results {
137 let idx = row_id.as_u64() as usize;
138 let exact_dist = match raw_vectors.get(idx) {
139 Some(v) => exact_distance(metric, query, v),
140 None => {
141 error!(
142 "ailake: invariant violated — row_id {} out of bounds \
143 (raw_vectors.len={}, file={}); \
144 Parquet row count and HNSW node count are out of sync; \
145 file may be corrupt — run compaction to rebuild",
146 idx,
147 raw_vectors.len(),
148 file_entry.path
149 );
150 f32::INFINITY
151 }
152 };
153 all_results.push(SearchResult {
154 row_id,
155 distance: exact_dist,
156 file_path: file_entry.path.clone(),
157 });
158 }
159 } else {
160 for (row_id, distance) in local_results {
161 all_results.push(SearchResult {
162 row_id,
163 distance,
164 file_path: file_entry.path.clone(),
165 });
166 }
167 }
168 }
169
170 all_results.sort_by(|a, b| {
172 a.distance
173 .partial_cmp(&b.distance)
174 .unwrap_or(std::cmp::Ordering::Equal)
175 });
176 all_results.truncate(config.top_k);
177 Ok(all_results)
178}
179
180fn flat_search(
182 raw: &[Vec<f32>],
183 query: &[f32],
184 top_k: usize,
185 metric: VectorMetric,
186) -> Vec<(RowId, f32)> {
187 let mut results: Vec<(RowId, f32)> = raw
188 .iter()
189 .enumerate()
190 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
191 .collect();
192 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
193 results.truncate(top_k);
194 results
195}
196
197fn parse_metric(s: &str) -> VectorMetric {
198 match s {
199 "euclidean" => VectorMetric::Euclidean,
200 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
201 _ => VectorMetric::Cosine,
202 }
203}
204
/// A table's data files loaded into memory once by `SearchSession::load`, so
/// they can be searched repeatedly without re-reading storage or the catalog.
///
/// Files committed after loading are not visible to the session.
pub struct SearchSession {
    /// One entry per loaded data file.
    shards: Vec<LoadedShard>,
    /// Metric parsed from the table property `ailake.vector-metric` at load time.
    metric: VectorMetric,
}
213
/// One data file held in memory by a `SearchSession`.
struct LoadedShard {
    /// Catalog entry for the file; supplies the path reported in results and
    /// the centroid metadata decoded for pruning in `search_query`.
    entry: DataFileEntry,
    /// Index loaded from the file; `None` when the file is searched by flat scan.
    index: Option<AnyIndex>,
    /// Raw vectors in row order. Present for flat-scan files, and for indexed
    /// files when the session was loaded with `load_raw = true`.
    raw_vectors: Option<Vec<Vec<f32>>>,
}
222
223impl SearchSession {
224 pub async fn load(
230 table: &TableIdent,
231 vector_column: &str,
232 dim: u32,
233 catalog: Arc<dyn CatalogProvider>,
234 store: Arc<dyn Store>,
235 load_raw: bool,
236 ) -> AilakeResult<Self> {
237 let all_files = catalog.list_files(table, None).await?;
238 let table_meta = catalog.load_table(table).await?;
239 let metric = parse_metric(
240 table_meta
241 .properties
242 .get("ailake.vector-metric")
243 .map(String::as_str)
244 .unwrap_or("cosine"),
245 );
246
247 let mut shards = Vec::with_capacity(all_files.len());
248 for entry in all_files {
249 let file_bytes: Bytes = store.get(&entry.path).await?;
250 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
251
252 if entry.index_status == IndexStatus::Indexing {
253 let (_, raw_vecs) = reader.read_parquet()?;
255 shards.push(LoadedShard {
256 entry,
257 index: None,
258 raw_vectors: Some(raw_vecs),
259 });
260 } else if reader.is_ailake_file() {
261 let mut index = reader.load_any_index_for_column(vector_column)?;
262 let raw_vectors = if load_raw {
263 index.quantize_to_f16();
264 let (_, vecs) = reader.read_parquet()?;
265 Some(vecs)
266 } else {
267 None
268 };
269 shards.push(LoadedShard {
270 entry,
271 index: Some(index),
272 raw_vectors,
273 });
274 }
275 }
276
277 Ok(Self { shards, metric })
278 }
279
280 pub fn shard_count(&self) -> usize {
282 self.shards.len()
283 }
284
285 pub fn search_batch(
294 &self,
295 queries: &[Vec<f32>],
296 config: &SearchConfig,
297 ) -> Vec<Vec<SearchResult>> {
298 if queries.is_empty() {
299 return vec![];
300 }
301
302 let n_queries = queries.len();
303 let candidate_k = match config.rerank_factor {
304 Some(factor) => config.top_k * factor,
305 None => config.top_k,
306 };
307 let use_nvidia = ailake_index::hardware::detect_cuda();
308 let use_amd = ailake_index::hardware::detect_rocm();
309
310 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
312
313 for shard in &self.shards {
314 if let Some(raw) = &shard.raw_vectors {
315 if !raw.is_empty() {
317 let dim = raw[0].len();
318 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
319 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
320 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
321
322 let gpu_batch = if use_nvidia {
323 ailake_index::gpu::try_nvidia_search_batch(
324 &q_refs,
325 &row_ids,
326 &flat,
327 dim,
328 self.metric,
329 candidate_k,
330 )
331 } else if use_amd {
332 ailake_index::gpu::try_rocm_search_batch(
333 &q_refs,
334 &row_ids,
335 &flat,
336 dim,
337 self.metric,
338 candidate_k,
339 )
340 } else {
341 None
342 };
343
344 if let Some(batch) = gpu_batch {
345 for (qi, results) in batch.into_iter().enumerate() {
346 for (row_id, distance) in results {
347 all_results[qi].push(SearchResult {
348 row_id,
349 distance,
350 file_path: shard.entry.path.clone(),
351 });
352 }
353 }
354 continue;
355 }
356 }
357
358 for (qi, query) in queries.iter().enumerate() {
360 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
361 all_results[qi].push(SearchResult {
362 row_id,
363 distance,
364 file_path: shard.entry.path.clone(),
365 });
366 }
367 }
368 } else if let Some(index) = &shard.index {
369 let shard_results: Vec<Vec<SearchResult>> = queries
371 .par_iter()
372 .map(|query| {
373 index
374 .search(query, candidate_k, config.ef_search)
375 .into_iter()
376 .map(|(row_id, distance)| SearchResult {
377 row_id,
378 distance,
379 file_path: shard.entry.path.clone(),
380 })
381 .collect()
382 })
383 .collect();
384
385 for (qi, results) in shard_results.into_iter().enumerate() {
386 all_results[qi].extend(results);
387 }
388 }
389 }
390
391 for results in &mut all_results {
393 results.sort_by(|a, b| {
394 a.distance
395 .partial_cmp(&b.distance)
396 .unwrap_or(std::cmp::Ordering::Equal)
397 });
398 results.truncate(config.top_k);
399 }
400
401 all_results
402 }
403
404 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
406 let candidate_k = match config.rerank_factor {
407 Some(factor) => config.top_k * factor,
408 None => config.top_k,
409 };
410
411 let mut all_results: Vec<SearchResult> = self
412 .shards
413 .par_iter()
414 .flat_map(|shard| {
415 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
417 let dist = match self.metric {
418 VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
419 ailake_vec::cosine_distance(query, ¢roid.values)
420 }
421 VectorMetric::Euclidean => {
422 ailake_vec::euclidean_distance(query, ¢roid.values)
423 }
424 VectorMetric::DotProduct => {
425 -ailake_vec::dot_product(query, ¢roid.values)
426 }
427 };
428 if dist - centroid.radius > config.pruning_threshold {
429 return vec![];
430 }
431 }
432
433 if let Some(index) = &shard.index {
434 let local_results = index.search(query, candidate_k, config.ef_search);
436 if config.rerank_factor.is_some() {
437 if let Some(raw) = &shard.raw_vectors {
438 local_results
439 .into_iter()
440 .map(|(row_id, _approx_dist)| {
441 let idx = row_id.as_u64() as usize;
442 let exact_dist = raw
443 .get(idx)
444 .map(|v| exact_distance(self.metric, query, v))
445 .unwrap_or(f32::INFINITY);
446 SearchResult {
447 row_id,
448 distance: exact_dist,
449 file_path: shard.entry.path.clone(),
450 }
451 })
452 .collect()
453 } else {
454 local_results
455 .into_iter()
456 .map(|(row_id, distance)| SearchResult {
457 row_id,
458 distance,
459 file_path: shard.entry.path.clone(),
460 })
461 .collect()
462 }
463 } else {
464 local_results
465 .into_iter()
466 .map(|(row_id, distance)| SearchResult {
467 row_id,
468 distance,
469 file_path: shard.entry.path.clone(),
470 })
471 .collect()
472 }
473 } else if let Some(raw) = &shard.raw_vectors {
474 flat_search(raw, query, candidate_k, self.metric)
476 .into_iter()
477 .map(|(row_id, distance)| SearchResult {
478 row_id,
479 distance,
480 file_path: shard.entry.path.clone(),
481 })
482 .collect()
483 } else {
484 vec![]
485 }
486 })
487 .collect();
488
489 all_results.sort_by(|a, b| {
490 a.distance
491 .partial_cmp(&b.distance)
492 .unwrap_or(std::cmp::Ordering::Equal)
493 });
494 all_results.truncate(config.top_k);
495 all_results
496 }
497}
498
499pub async fn fetch_rows(
508 results: &[SearchResult],
509 store: Arc<dyn Store>,
510 vector_column: &str,
511 dim: u32,
512) -> AilakeResult<RecordBatch> {
513 use std::collections::HashMap;
514
515 use arrow_array::{ArrayRef, Float32Array, UInt32Array};
516 use arrow_schema::{DataType, Field, Schema};
517 use arrow_select::{concat::concat_batches, take::take};
518
519 if results.is_empty() {
520 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
521 }
522
523 let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
525 for (i, r) in results.iter().enumerate() {
526 by_file
527 .entry(r.file_path.as_str())
528 .or_default()
529 .push((r.row_id.as_u64(), r.distance, i));
530 }
531
532 use arrow_array::FixedSizeListArray;
533
534 let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
536
537 for (file_path, rows) in &by_file {
538 let bytes = store.get(file_path).await?;
539 let reader = AilakeFileReader::new(bytes, vector_column, dim);
540 let (batch, vectors) = reader.read_parquet()?;
541
542 for &(row_id, distance, pos) in rows {
543 let idx = row_id as usize;
544 if idx >= batch.num_rows() {
545 tracing::warn!(
546 "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
547 idx,
548 batch.num_rows(),
549 file_path
550 );
551 continue;
552 }
553
554 let indices = UInt32Array::from(vec![idx as u32]);
555 let row_cols: Vec<ArrayRef> = batch
556 .columns()
557 .iter()
558 .map(|col| {
559 take(col.as_ref(), &indices, None)
560 .map_err(|e| AilakeError::Arrow(e.to_string()))
561 })
562 .collect::<AilakeResult<Vec<_>>>()?;
563
564 let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
565 .map_err(|e| AilakeError::Arrow(e.to_string()))?;
566
567 let vec = vectors
569 .get(idx)
570 .cloned()
571 .unwrap_or_else(|| vec![0.0f32; dim as usize]);
572
573 collected.push((pos, distance, row_batch, vec));
574 }
575 }
576
577 if collected.is_empty() {
578 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
579 }
580
581 collected.sort_by_key(|(pos, _, _, _)| *pos);
583
584 let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
585 let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
586 let base_schema = collected[0].2.schema();
587
588 let combined =
589 concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
590
591 let flat_vecs: Vec<f32> = collected
593 .iter()
594 .flat_map(|(_, _, _, v)| v.iter().copied())
595 .collect();
596 let item_field = Arc::new(Field::new("item", DataType::Float32, false));
597 let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
598 let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
599 let vec_field = Arc::new(Field::new(
600 vector_column,
601 DataType::FixedSizeList(item_field, dim as i32),
602 false,
603 ));
604
605 let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
607 fields.push(vec_field);
608 fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
609 let new_schema = Arc::new(Schema::new(fields));
610
611 let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
612 columns.push(Arc::new(vec_col));
613 columns.push(Arc::new(Float32Array::from(distances)));
614
615 RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy for the `embedding` column: cosine metric, F16
    /// precision, raw vectors kept for reranking, no PQ, no explicit HNSW
    /// parameters.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
        }
    }

    /// Writes and commits `default.table` under `dir`: `rows` rows with an
    /// `id` column, where row `i` has a one-hot embedding at coordinate
    /// `i % dim`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        let embeddings: Vec<Vec<f32>> = (0..rows)
            .map(|i| {
                let hot = i % dim;
                (0..dim).map(|j| if j == hot { 1.0 } else { 0.0 }).collect()
            })
            .collect();

        let policy = make_policy(dim as u32);
        let mut writer = crate::TableWriter::create_or_open(catalog, store, policy, table)
            .await
            .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Opens the demo table with a fresh store and catalog rooted at `dir`.
    fn open_table(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>, TableIdent) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (store, catalog, TableIdent::new("default", "table"))
    }

    /// Config with `ef_search = 50` and an infinite pruning threshold.
    fn make_config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Runs `search` on the demo table's `embedding` column.
    async fn search_demo(
        dir: &TempDir,
        query: &[f32],
        config: SearchConfig,
        dim: usize,
    ) -> Vec<SearchResult> {
        let (store, catalog, table) = open_table(dir);
        search(&table, query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap()
    }

    /// Unit vector along the first axis.
    fn first_axis(dim: usize) -> Vec<f32> {
        let mut q = vec![0.0f32; dim];
        q[0] = 1.0;
        q
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_demo(&dir, &first_axis(dim), make_config(3, Some(2)), dim).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // Row 0 is one-hot on axis 0, identical to the query.
        let results = search_demo(&dir, &first_axis(dim), make_config(1, Some(4)), dim).await;

        assert_eq!(results.len(), 1);
        let best = &results[0];
        assert!(best.distance < 1e-3, "distance was {}", best.distance);
        assert_eq!(best.row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let query = first_axis(dim);
        let plain = search_demo(&dir, &query, make_config(2, None), dim).await;
        let reranked = search_demo(&dir, &query, make_config(2, Some(2)), dim).await;

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}