1use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::RecordBatch;
14use bytes::Bytes;
15
16use crate::pruner::VectorPruner;
17
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and refine with the `with_*` builder
/// methods, or fill the public fields directly.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results kept after the candidates of all files are
    /// merged and sorted by ascending distance.
    pub top_k: usize,
    /// Search-breadth value forwarded unchanged to each per-file index search.
    pub ef_search: usize,
    /// Threshold used by geometric pruning: a file whose centroid-based lower
    /// bound exceeds this value is skipped. The default, `f32::INFINITY`, can
    /// never be exceeded by that comparison.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, every file contributes `top_k * factor` candidates
    /// that are re-scored with exact distances where raw vectors are
    /// available. `None` disables re-ranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// Ten results, `ef_search` of 50, an infinite pruning threshold and no
    /// re-ranking.
    fn default() -> Self {
        Self {
            top_k: 10,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}

impl SearchConfig {
    /// Returns the configuration with `pruning_threshold` replaced by `threshold`.
    ///
    /// The builder consumes `self`; dropping the returned value discards the
    /// change, hence `#[must_use]`.
    #[must_use]
    pub fn with_pruning(mut self, threshold: f32) -> Self {
        self.pruning_threshold = threshold;
        self
    }

    /// Returns the configuration with re-ranking enabled at the given `factor`.
    ///
    /// The builder consumes `self`; dropping the returned value discards the
    /// change, hence `#[must_use]`.
    #[must_use]
    pub fn with_reranking(mut self, factor: usize) -> Self {
        self.rerank_factor = Some(factor);
        self
    }
}
55
/// A single hit returned by the search functions in this module.
#[derive(Debug)]
pub struct SearchResult {
    /// Identifier of the matching row. Flat scans number rows from 0 within
    /// each file, so the id is only meaningful together with `file_path`.
    pub row_id: RowId,
    /// Distance between the query and the row; results are sorted ascending
    /// on this value, so smaller means closer.
    pub distance: f32,
    /// Store path of the data file the row was found in.
    pub file_path: String,
}
62
63pub async fn search(
71 table: &TableIdent,
72 query: &[f32],
73 config: SearchConfig,
74 vector_column: &str,
75 dim: u32,
76 catalog: Arc<dyn CatalogProvider>,
77 store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79 let all_files = catalog.list_files(table, None).await?;
81
82 let table_meta = catalog.load_table(table).await?;
84
85 if let Some(table_dim_str) = table_meta.properties.get("ailake.vector-dim") {
87 if let Ok(table_dim) = table_dim_str.parse::<u32>() {
88 let query_dim = query.len() as u32;
89 if query_dim != table_dim {
90 let table_model = table_meta
91 .properties
92 .get(EmbeddingModelInfo::property_key())
93 .cloned()
94 .unwrap_or_else(|| format!("dim={}", table_dim));
95 return Err(AilakeError::ModelMismatch {
96 table_model,
97 table_dim,
98 batch_model: format!("query dim={}", query_dim),
99 batch_dim: query_dim,
100 });
101 }
102 }
103 }
104
105 let metric = parse_metric(
106 table_meta
107 .properties
108 .get("ailake.vector-metric")
109 .map(String::as_str)
110 .unwrap_or("cosine"),
111 );
112
113 let total_files = all_files.len();
115 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
116 debug!(
117 "ailake: geometric pruning — {}/{} files survive (threshold={})",
118 surviving_files.len(),
119 total_files,
120 config.pruning_threshold
121 );
122
123 let candidate_k = match config.rerank_factor {
124 Some(factor) => config.top_k * factor,
125 None => config.top_k,
126 };
127
128 let mut all_results: Vec<SearchResult> = Vec::new();
129
130 for file_entry in &surviving_files {
131 let file_bytes: Bytes = store.get(&file_entry.path).await?;
132 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
133
134 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
135 debug!(
137 "ailake: flat scan fallback for {} (index_status={:?})",
138 file_entry.path, file_entry.index_status
139 );
140 let (_, raw_vectors) = reader.read_parquet()?;
141 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
142 all_results.push(SearchResult {
143 row_id,
144 distance,
145 file_path: file_entry.path.clone(),
146 });
147 }
148 continue;
149 }
150
151 let index = reader.load_any_index_for_column(vector_column)?;
152 let local_results = index.search(query, candidate_k, config.ef_search);
153
154 if config.rerank_factor.is_some() {
155 let (_, raw_vectors) = reader.read_parquet()?;
157 for (row_id, _approx_dist) in local_results {
158 let idx = row_id.as_u64() as usize;
159 let exact_dist = match raw_vectors.get(idx) {
160 Some(v) => exact_distance(metric, query, v),
161 None => {
162 error!(
163 "ailake: invariant violated — row_id {} out of bounds \
164 (raw_vectors.len={}, file={}); \
165 Parquet row count and HNSW node count are out of sync; \
166 file may be corrupt — run compaction to rebuild",
167 idx,
168 raw_vectors.len(),
169 file_entry.path
170 );
171 f32::INFINITY
172 }
173 };
174 all_results.push(SearchResult {
175 row_id,
176 distance: exact_dist,
177 file_path: file_entry.path.clone(),
178 });
179 }
180 } else {
181 for (row_id, distance) in local_results {
182 all_results.push(SearchResult {
183 row_id,
184 distance,
185 file_path: file_entry.path.clone(),
186 });
187 }
188 }
189 }
190
191 all_results.sort_by(|a, b| {
193 a.distance
194 .partial_cmp(&b.distance)
195 .unwrap_or(std::cmp::Ordering::Equal)
196 });
197 all_results.truncate(config.top_k);
198 Ok(all_results)
199}
200
201fn flat_search(
203 raw: &[Vec<f32>],
204 query: &[f32],
205 top_k: usize,
206 metric: VectorMetric,
207) -> Vec<(RowId, f32)> {
208 let mut results: Vec<(RowId, f32)> = raw
209 .iter()
210 .enumerate()
211 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
212 .collect();
213 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
214 results.truncate(top_k);
215 results
216}
217
218fn parse_metric(s: &str) -> VectorMetric {
219 match s {
220 "euclidean" => VectorMetric::Euclidean,
221 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
222 _ => VectorMetric::Cosine,
223 }
224}
225
/// In-memory search session over a table.
///
/// Built by `SearchSession::load`, which fetches the table's data files once
/// so that `search_batch` / `search_query` can run without further store or
/// catalog access.
pub struct SearchSession {
    /// One entry per data file kept by `load`.
    shards: Vec<LoadedShard>,
    /// Distance metric parsed from the table's `ailake.vector-metric`
    /// property at load time (cosine when the property is absent).
    metric: VectorMetric,
}
234
/// Search structures held in memory for one data file of a session.
struct LoadedShard {
    /// Catalog entry describing the file (path, index status, ...).
    entry: DataFileEntry,
    /// Index loaded from the file; `None` for files that were still being
    /// indexed when the session was loaded.
    index: Option<AnyIndex>,
    /// Raw vectors read from the file's Parquet data. Present for files that
    /// were still being indexed, and for indexed files when the session was
    /// loaded with `load_raw = true`.
    raw_vectors: Option<Vec<Vec<f32>>>,
}
243
244impl SearchSession {
245 pub async fn load(
251 table: &TableIdent,
252 vector_column: &str,
253 dim: u32,
254 catalog: Arc<dyn CatalogProvider>,
255 store: Arc<dyn Store>,
256 load_raw: bool,
257 ) -> AilakeResult<Self> {
258 let all_files = catalog.list_files(table, None).await?;
259 let table_meta = catalog.load_table(table).await?;
260 let metric = parse_metric(
261 table_meta
262 .properties
263 .get("ailake.vector-metric")
264 .map(String::as_str)
265 .unwrap_or("cosine"),
266 );
267
268 let mut shards = Vec::with_capacity(all_files.len());
269 for entry in all_files {
270 let file_bytes: Bytes = store.get(&entry.path).await?;
271 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
272
273 if entry.index_status == IndexStatus::Indexing {
274 let (_, raw_vecs) = reader.read_parquet()?;
276 shards.push(LoadedShard {
277 entry,
278 index: None,
279 raw_vectors: Some(raw_vecs),
280 });
281 } else if reader.is_ailake_file() {
282 let mut index = reader.load_any_index_for_column(vector_column)?;
283 let raw_vectors = if load_raw {
284 index.quantize_to_f16();
285 let (_, vecs) = reader.read_parquet()?;
286 Some(vecs)
287 } else {
288 None
289 };
290 shards.push(LoadedShard {
291 entry,
292 index: Some(index),
293 raw_vectors,
294 });
295 }
296 }
297
298 Ok(Self { shards, metric })
299 }
300
301 pub fn shard_count(&self) -> usize {
303 self.shards.len()
304 }
305
306 pub fn search_batch(
315 &self,
316 queries: &[Vec<f32>],
317 config: &SearchConfig,
318 ) -> Vec<Vec<SearchResult>> {
319 if queries.is_empty() {
320 return vec![];
321 }
322
323 let n_queries = queries.len();
324 let candidate_k = match config.rerank_factor {
325 Some(factor) => config.top_k * factor,
326 None => config.top_k,
327 };
328 let use_nvidia = ailake_index::hardware::detect_cuda();
329 let use_amd = ailake_index::hardware::detect_rocm();
330
331 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
333
334 for shard in &self.shards {
335 if let Some(raw) = &shard.raw_vectors {
336 if !raw.is_empty() {
338 let dim = raw[0].len();
339 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
340 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
341 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
342
343 let gpu_batch = if use_nvidia {
344 ailake_index::gpu::try_nvidia_search_batch(
345 &q_refs,
346 &row_ids,
347 &flat,
348 dim,
349 self.metric,
350 candidate_k,
351 )
352 } else if use_amd {
353 ailake_index::gpu::try_rocm_search_batch(
354 &q_refs,
355 &row_ids,
356 &flat,
357 dim,
358 self.metric,
359 candidate_k,
360 )
361 } else {
362 None
363 };
364
365 if let Some(batch) = gpu_batch {
366 for (qi, results) in batch.into_iter().enumerate() {
367 for (row_id, distance) in results {
368 all_results[qi].push(SearchResult {
369 row_id,
370 distance,
371 file_path: shard.entry.path.clone(),
372 });
373 }
374 }
375 continue;
376 }
377 }
378
379 for (qi, query) in queries.iter().enumerate() {
381 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
382 all_results[qi].push(SearchResult {
383 row_id,
384 distance,
385 file_path: shard.entry.path.clone(),
386 });
387 }
388 }
389 } else if let Some(index) = &shard.index {
390 let shard_results: Vec<Vec<SearchResult>> = queries
392 .par_iter()
393 .map(|query| {
394 index
395 .search(query, candidate_k, config.ef_search)
396 .into_iter()
397 .map(|(row_id, distance)| SearchResult {
398 row_id,
399 distance,
400 file_path: shard.entry.path.clone(),
401 })
402 .collect()
403 })
404 .collect();
405
406 for (qi, results) in shard_results.into_iter().enumerate() {
407 all_results[qi].extend(results);
408 }
409 }
410 }
411
412 for results in &mut all_results {
414 results.sort_by(|a, b| {
415 a.distance
416 .partial_cmp(&b.distance)
417 .unwrap_or(std::cmp::Ordering::Equal)
418 });
419 results.truncate(config.top_k);
420 }
421
422 all_results
423 }
424
425 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
427 let candidate_k = match config.rerank_factor {
428 Some(factor) => config.top_k * factor,
429 None => config.top_k,
430 };
431
432 let mut all_results: Vec<SearchResult> = self
433 .shards
434 .par_iter()
435 .flat_map(|shard| {
436 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
438 let dist = match self.metric {
439 VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
440 ailake_vec::cosine_distance(query, ¢roid.values)
441 }
442 VectorMetric::Euclidean => {
443 ailake_vec::euclidean_distance(query, ¢roid.values)
444 }
445 VectorMetric::DotProduct => {
446 -ailake_vec::dot_product(query, ¢roid.values)
447 }
448 };
449 if dist - centroid.radius > config.pruning_threshold {
450 return vec![];
451 }
452 }
453
454 if let Some(index) = &shard.index {
455 let local_results = index.search(query, candidate_k, config.ef_search);
457 if config.rerank_factor.is_some() {
458 if let Some(raw) = &shard.raw_vectors {
459 local_results
460 .into_iter()
461 .map(|(row_id, _approx_dist)| {
462 let idx = row_id.as_u64() as usize;
463 let exact_dist = raw
464 .get(idx)
465 .map(|v| exact_distance(self.metric, query, v))
466 .unwrap_or(f32::INFINITY);
467 SearchResult {
468 row_id,
469 distance: exact_dist,
470 file_path: shard.entry.path.clone(),
471 }
472 })
473 .collect()
474 } else {
475 local_results
476 .into_iter()
477 .map(|(row_id, distance)| SearchResult {
478 row_id,
479 distance,
480 file_path: shard.entry.path.clone(),
481 })
482 .collect()
483 }
484 } else {
485 local_results
486 .into_iter()
487 .map(|(row_id, distance)| SearchResult {
488 row_id,
489 distance,
490 file_path: shard.entry.path.clone(),
491 })
492 .collect()
493 }
494 } else if let Some(raw) = &shard.raw_vectors {
495 flat_search(raw, query, candidate_k, self.metric)
497 .into_iter()
498 .map(|(row_id, distance)| SearchResult {
499 row_id,
500 distance,
501 file_path: shard.entry.path.clone(),
502 })
503 .collect()
504 } else {
505 vec![]
506 }
507 })
508 .collect();
509
510 all_results.sort_by(|a, b| {
511 a.distance
512 .partial_cmp(&b.distance)
513 .unwrap_or(std::cmp::Ordering::Equal)
514 });
515 all_results.truncate(config.top_k);
516 all_results
517 }
518}
519
520pub async fn fetch_rows(
529 results: &[SearchResult],
530 store: Arc<dyn Store>,
531 vector_column: &str,
532 dim: u32,
533) -> AilakeResult<RecordBatch> {
534 use std::collections::HashMap;
535
536 use arrow_array::{ArrayRef, Float32Array, UInt32Array};
537 use arrow_schema::{DataType, Field, Schema};
538 use arrow_select::{concat::concat_batches, take::take};
539
540 if results.is_empty() {
541 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
542 }
543
544 let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
546 for (i, r) in results.iter().enumerate() {
547 by_file
548 .entry(r.file_path.as_str())
549 .or_default()
550 .push((r.row_id.as_u64(), r.distance, i));
551 }
552
553 use arrow_array::FixedSizeListArray;
554
555 let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
557
558 for (file_path, rows) in &by_file {
559 let bytes = store.get(file_path).await?;
560 let reader = AilakeFileReader::new(bytes, vector_column, dim);
561 let (batch, vectors) = reader.read_parquet()?;
562
563 for &(row_id, distance, pos) in rows {
564 let idx = row_id as usize;
565 if idx >= batch.num_rows() {
566 tracing::warn!(
567 "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
568 idx,
569 batch.num_rows(),
570 file_path
571 );
572 continue;
573 }
574
575 let indices = UInt32Array::from(vec![idx as u32]);
576 let row_cols: Vec<ArrayRef> = batch
577 .columns()
578 .iter()
579 .map(|col| {
580 take(col.as_ref(), &indices, None)
581 .map_err(|e| AilakeError::Arrow(e.to_string()))
582 })
583 .collect::<AilakeResult<Vec<_>>>()?;
584
585 let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
586 .map_err(|e| AilakeError::Arrow(e.to_string()))?;
587
588 let vec = vectors
590 .get(idx)
591 .cloned()
592 .unwrap_or_else(|| vec![0.0f32; dim as usize]);
593
594 collected.push((pos, distance, row_batch, vec));
595 }
596 }
597
598 if collected.is_empty() {
599 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
600 }
601
602 collected.sort_by_key(|(pos, _, _, _)| *pos);
604
605 let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
606 let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
607 let base_schema = collected[0].2.schema();
608
609 let combined =
610 concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
611
612 let flat_vecs: Vec<f32> = collected
614 .iter()
615 .flat_map(|(_, _, _, v)| v.iter().copied())
616 .collect();
617 let item_field = Arc::new(Field::new("item", DataType::Float32, false));
618 let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
619 let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
620 let vec_field = Arc::new(Field::new(
621 vector_column,
622 DataType::FixedSizeList(item_field, dim as i32),
623 false,
624 ));
625
626 let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
628 fields.push(vec_field);
629 fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
630 let new_schema = Arc::new(Schema::new(fields));
631
632 let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
633 columns.push(Arc::new(vec_col));
634 columns.push(Arc::new(Float32Array::from(distances)));
635
636 RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
637}
638
639#[cfg(test)]
640mod tests {
641 use super::*;
642 use ailake_catalog::{HadoopCatalog, TableIdent};
643 use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
644 use ailake_store::LocalStore;
645 use arrow_array::{Int32Array, RecordBatch};
646 use arrow_schema::{DataType, Field, Schema};
647 use std::sync::Arc;
648 use tempfile::TempDir;
649
650 fn make_policy(dim: u32) -> VectorStoragePolicy {
651 VectorStoragePolicy {
652 column_name: "embedding".to_string(),
653 dim,
654 metric: VectorMetric::Cosine,
655 precision: VectorPrecision::F16,
656 pq: None,
657 keep_raw_for_reranking: true,
658 pre_normalize: false,
659 hnsw_m: None,
660 hnsw_ef_construction: None,
661 ivf_residual: false,
662 embedding_model: None,
663 }
664 }
665
666 async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
667 let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
668 let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
669 let table = TableIdent::new("default", "table");
670
671 let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
672 let ids: Vec<i32> = (0..rows as i32).collect();
673 let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();
674
675 let embeddings: Vec<Vec<f32>> = (0..rows)
677 .map(|i| {
678 let mut v = vec![0.0f32; dim];
679 v[i % dim] = 1.0;
680 v
681 })
682 .collect();
683
684 let mut writer =
685 crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
686 .await
687 .unwrap();
688 writer.write_batch(&batch, &embeddings).await.unwrap();
689 writer.commit().await.unwrap();
690 }
691
692 #[tokio::test]
693 async fn rerank_returns_correct_top_k_count() {
694 let dir = TempDir::new().unwrap();
695 let dim = 8usize;
696 write_demo_table(&dir, dim, 8).await;
697
698 let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
699 let catalog: Arc<dyn CatalogProvider> =
700 Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
701 let table = TableIdent::new("default", "table");
702
703 let query = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
704 let config = SearchConfig {
705 top_k: 3,
706 ef_search: 50,
707 pruning_threshold: f32::INFINITY,
708 rerank_factor: Some(2),
709 };
710
711 let results = search(
712 &table,
713 &query,
714 config,
715 "embedding",
716 dim as u32,
717 catalog,
718 store,
719 )
720 .await
721 .unwrap();
722
723 assert_eq!(results.len(), 3);
724 }
725
726 #[tokio::test]
727 async fn rerank_nearest_is_exact_match() {
728 let dir = TempDir::new().unwrap();
729 let dim = 8usize;
730 write_demo_table(&dir, dim, 8).await;
731
732 let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
733 let catalog: Arc<dyn CatalogProvider> =
734 Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
735 let table = TableIdent::new("default", "table");
736
737 let query = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
739 let config = SearchConfig {
740 top_k: 1,
741 ef_search: 50,
742 pruning_threshold: f32::INFINITY,
743 rerank_factor: Some(4),
744 };
745
746 let results = search(
747 &table,
748 &query,
749 config,
750 "embedding",
751 dim as u32,
752 catalog,
753 store,
754 )
755 .await
756 .unwrap();
757
758 assert_eq!(results.len(), 1);
759 assert!(
761 results[0].distance < 1e-3,
762 "distance was {}",
763 results[0].distance
764 );
765 assert_eq!(results[0].row_id, RowId::new(0));
766 }
767
768 #[tokio::test]
769 async fn no_rerank_matches_default_behavior() {
770 let dir = TempDir::new().unwrap();
771 let dim = 4usize;
772 write_demo_table(&dir, dim, 4).await;
773
774 let store_a: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
775 let store_b: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
776 let cat_a: Arc<dyn CatalogProvider> =
777 Arc::new(HadoopCatalog::new(store_a.clone(), "warehouse"));
778 let cat_b: Arc<dyn CatalogProvider> =
779 Arc::new(HadoopCatalog::new(store_b.clone(), "warehouse"));
780 let table = TableIdent::new("default", "table");
781
782 let query = vec![1.0f32, 0.0, 0.0, 0.0];
783 let cfg_plain = SearchConfig {
784 top_k: 2,
785 ef_search: 50,
786 pruning_threshold: f32::INFINITY,
787 rerank_factor: None,
788 };
789 let cfg_rerank = SearchConfig {
790 top_k: 2,
791 ef_search: 50,
792 pruning_threshold: f32::INFINITY,
793 rerank_factor: Some(2),
794 };
795
796 let plain = search(
797 &table,
798 &query,
799 cfg_plain,
800 "embedding",
801 dim as u32,
802 cat_a,
803 store_a,
804 )
805 .await
806 .unwrap();
807 let reranked = search(
808 &table,
809 &query,
810 cfg_rerank,
811 "embedding",
812 dim as u32,
813 cat_b,
814 store_b,
815 )
816 .await
817 .unwrap();
818
819 assert_eq!(plain[0].row_id, reranked[0].row_id);
821 }
822}