1use std::sync::Arc;
2
3use rayon::prelude::*;
4
5use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
6use ailake_core::{AilakeResult, RowId, VectorMetric};
7use ailake_file::AilakeFileReader;
8use ailake_index::AnyIndex;
9use ailake_store::Store;
10use ailake_vec::exact_distance;
11use bytes::Bytes;
12
13use crate::pruner::VectorPruner;
14
/// Tunable parameters for a single vector search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results returned to the caller after all per-file
    /// results have been merged and sorted by distance.
    pub top_k: usize,
    /// Search-breadth parameter forwarded unchanged to the index's `search` call.
    pub ef_search: usize,
    /// Threshold used for file-level pruning. The default of `f32::INFINITY`
    /// means the centroid check in `search_query` never rejects a shard.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, each file is asked for `top_k * factor` candidates
    /// and index hits are re-scored with exact distances computed from the raw
    /// vectors. `None` disables reranking.
    pub rerank_factor: Option<usize>,
}
29
30impl Default for SearchConfig {
31 fn default() -> Self {
32 Self {
33 top_k: 10,
34 ef_search: 50,
35 pruning_threshold: f32::INFINITY,
36 rerank_factor: None,
37 }
38 }
39}
40
41impl SearchConfig {
42 pub fn with_pruning(mut self, threshold: f32) -> Self {
43 self.pruning_threshold = threshold;
44 self
45 }
46
47 pub fn with_reranking(mut self, factor: usize) -> Self {
48 self.rerank_factor = Some(factor);
49 self
50 }
51}
52
53#[derive(Debug)]
54pub struct SearchResult {
55 pub row_id: RowId,
56 pub distance: f32,
57 pub file_path: String,
58}
59
60pub async fn search(
68 table: &TableIdent,
69 query: &[f32],
70 config: SearchConfig,
71 vector_column: &str,
72 dim: u32,
73 catalog: Arc<dyn CatalogProvider>,
74 store: Arc<dyn Store>,
75) -> AilakeResult<Vec<SearchResult>> {
76 let all_files = catalog.list_files(table, None).await?;
78
79 let table_meta = catalog.load_table(table).await?;
81 let metric = parse_metric(
82 table_meta
83 .properties
84 .get("ailake.vector-metric")
85 .map(String::as_str)
86 .unwrap_or("cosine"),
87 );
88
89 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
91
92 let candidate_k = match config.rerank_factor {
93 Some(factor) => config.top_k * factor,
94 None => config.top_k,
95 };
96
97 let mut all_results: Vec<SearchResult> = Vec::new();
98
99 for file_entry in &surviving_files {
100 let file_bytes: Bytes = store.get(&file_entry.path).await?;
101 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
102
103 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
104 let (_, raw_vectors) = reader.read_parquet()?;
106 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
107 all_results.push(SearchResult {
108 row_id,
109 distance,
110 file_path: file_entry.path.clone(),
111 });
112 }
113 continue;
114 }
115
116 let index = reader.load_any_index_for_column(vector_column)?;
117 let local_results = index.search(query, candidate_k, config.ef_search);
118
119 if config.rerank_factor.is_some() {
120 let (_, raw_vectors) = reader.read_parquet()?;
122 for (row_id, _approx_dist) in local_results {
123 let idx = row_id.as_u64() as usize;
124 let exact_dist = raw_vectors
125 .get(idx)
126 .map(|v| exact_distance(metric, query, v))
127 .unwrap_or(f32::INFINITY);
128 all_results.push(SearchResult {
129 row_id,
130 distance: exact_dist,
131 file_path: file_entry.path.clone(),
132 });
133 }
134 } else {
135 for (row_id, distance) in local_results {
136 all_results.push(SearchResult {
137 row_id,
138 distance,
139 file_path: file_entry.path.clone(),
140 });
141 }
142 }
143 }
144
145 all_results.sort_by(|a, b| {
147 a.distance
148 .partial_cmp(&b.distance)
149 .unwrap_or(std::cmp::Ordering::Equal)
150 });
151 all_results.truncate(config.top_k);
152 Ok(all_results)
153}
154
155fn flat_search(
157 raw: &[Vec<f32>],
158 query: &[f32],
159 top_k: usize,
160 metric: VectorMetric,
161) -> Vec<(RowId, f32)> {
162 let mut results: Vec<(RowId, f32)> = raw
163 .iter()
164 .enumerate()
165 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
166 .collect();
167 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
168 results.truncate(top_k);
169 results
170}
171
172fn parse_metric(s: &str) -> VectorMetric {
173 match s {
174 "euclidean" => VectorMetric::Euclidean,
175 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
176 _ => VectorMetric::Cosine,
177 }
178}
179
180pub struct SearchSession {
185 shards: Vec<LoadedShard>,
186 metric: VectorMetric,
187}
188
189struct LoadedShard {
190 entry: DataFileEntry,
191 index: Option<AnyIndex>,
193 raw_vectors: Option<Vec<Vec<f32>>>,
196}
197
198impl SearchSession {
199 pub async fn load(
205 table: &TableIdent,
206 vector_column: &str,
207 dim: u32,
208 catalog: Arc<dyn CatalogProvider>,
209 store: Arc<dyn Store>,
210 load_raw: bool,
211 ) -> AilakeResult<Self> {
212 let all_files = catalog.list_files(table, None).await?;
213 let table_meta = catalog.load_table(table).await?;
214 let metric = parse_metric(
215 table_meta
216 .properties
217 .get("ailake.vector-metric")
218 .map(String::as_str)
219 .unwrap_or("cosine"),
220 );
221
222 let mut shards = Vec::with_capacity(all_files.len());
223 for entry in all_files {
224 let file_bytes: Bytes = store.get(&entry.path).await?;
225 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
226
227 if entry.index_status == IndexStatus::Indexing {
228 let (_, raw_vecs) = reader.read_parquet()?;
230 shards.push(LoadedShard {
231 entry,
232 index: None,
233 raw_vectors: Some(raw_vecs),
234 });
235 } else if reader.is_ailake_file() {
236 let mut index = reader.load_any_index_for_column(vector_column)?;
237 let raw_vectors = if load_raw {
238 index.quantize_to_f16();
239 let (_, vecs) = reader.read_parquet()?;
240 Some(vecs)
241 } else {
242 None
243 };
244 shards.push(LoadedShard {
245 entry,
246 index: Some(index),
247 raw_vectors,
248 });
249 }
250 }
251
252 Ok(Self { shards, metric })
253 }
254
255 pub fn shard_count(&self) -> usize {
257 self.shards.len()
258 }
259
260 pub fn search_batch(
269 &self,
270 queries: &[Vec<f32>],
271 config: &SearchConfig,
272 ) -> Vec<Vec<SearchResult>> {
273 if queries.is_empty() {
274 return vec![];
275 }
276
277 let n_queries = queries.len();
278 let candidate_k = match config.rerank_factor {
279 Some(factor) => config.top_k * factor,
280 None => config.top_k,
281 };
282 let use_nvidia = ailake_index::hardware::detect_cuda();
283 let use_amd = ailake_index::hardware::detect_rocm();
284
285 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
287
288 for shard in &self.shards {
289 if let Some(raw) = &shard.raw_vectors {
290 if !raw.is_empty() {
292 let dim = raw[0].len();
293 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
294 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
295 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
296
297 let gpu_batch = if use_nvidia {
298 ailake_index::gpu::try_nvidia_search_batch(
299 &q_refs,
300 &row_ids,
301 &flat,
302 dim,
303 self.metric,
304 candidate_k,
305 )
306 } else if use_amd {
307 ailake_index::gpu::try_rocm_search_batch(
308 &q_refs,
309 &row_ids,
310 &flat,
311 dim,
312 self.metric,
313 candidate_k,
314 )
315 } else {
316 None
317 };
318
319 if let Some(batch) = gpu_batch {
320 for (qi, results) in batch.into_iter().enumerate() {
321 for (row_id, distance) in results {
322 all_results[qi].push(SearchResult {
323 row_id,
324 distance,
325 file_path: shard.entry.path.clone(),
326 });
327 }
328 }
329 continue;
330 }
331 }
332
333 for (qi, query) in queries.iter().enumerate() {
335 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
336 all_results[qi].push(SearchResult {
337 row_id,
338 distance,
339 file_path: shard.entry.path.clone(),
340 });
341 }
342 }
343 } else if let Some(index) = &shard.index {
344 let shard_results: Vec<Vec<SearchResult>> = queries
346 .par_iter()
347 .map(|query| {
348 index
349 .search(query, candidate_k, config.ef_search)
350 .into_iter()
351 .map(|(row_id, distance)| SearchResult {
352 row_id,
353 distance,
354 file_path: shard.entry.path.clone(),
355 })
356 .collect()
357 })
358 .collect();
359
360 for (qi, results) in shard_results.into_iter().enumerate() {
361 all_results[qi].extend(results);
362 }
363 }
364 }
365
366 for results in &mut all_results {
368 results.sort_by(|a, b| {
369 a.distance
370 .partial_cmp(&b.distance)
371 .unwrap_or(std::cmp::Ordering::Equal)
372 });
373 results.truncate(config.top_k);
374 }
375
376 all_results
377 }
378
379 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
381 let candidate_k = match config.rerank_factor {
382 Some(factor) => config.top_k * factor,
383 None => config.top_k,
384 };
385
386 let mut all_results: Vec<SearchResult> = self
387 .shards
388 .par_iter()
389 .flat_map(|shard| {
390 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
392 let dist = match self.metric {
393 VectorMetric::Cosine => {
394 ailake_vec::cosine_distance(query, ¢roid.values)
395 }
396 VectorMetric::Euclidean => {
397 ailake_vec::euclidean_distance(query, ¢roid.values)
398 }
399 VectorMetric::DotProduct => {
400 -ailake_vec::dot_product(query, ¢roid.values)
401 }
402 };
403 if dist - centroid.radius > config.pruning_threshold {
404 return vec![];
405 }
406 }
407
408 if let Some(index) = &shard.index {
409 let local_results = index.search(query, candidate_k, config.ef_search);
411 if config.rerank_factor.is_some() {
412 if let Some(raw) = &shard.raw_vectors {
413 local_results
414 .into_iter()
415 .map(|(row_id, _approx_dist)| {
416 let idx = row_id.as_u64() as usize;
417 let exact_dist = raw
418 .get(idx)
419 .map(|v| exact_distance(self.metric, query, v))
420 .unwrap_or(f32::INFINITY);
421 SearchResult {
422 row_id,
423 distance: exact_dist,
424 file_path: shard.entry.path.clone(),
425 }
426 })
427 .collect()
428 } else {
429 local_results
430 .into_iter()
431 .map(|(row_id, distance)| SearchResult {
432 row_id,
433 distance,
434 file_path: shard.entry.path.clone(),
435 })
436 .collect()
437 }
438 } else {
439 local_results
440 .into_iter()
441 .map(|(row_id, distance)| SearchResult {
442 row_id,
443 distance,
444 file_path: shard.entry.path.clone(),
445 })
446 .collect()
447 }
448 } else if let Some(raw) = &shard.raw_vectors {
449 flat_search(raw, query, candidate_k, self.metric)
451 .into_iter()
452 .map(|(row_id, distance)| SearchResult {
453 row_id,
454 distance,
455 file_path: shard.entry.path.clone(),
456 })
457 .collect()
458 } else {
459 vec![]
460 }
461 })
462 .collect();
463
464 all_results.sort_by(|a, b| {
465 a.distance
466 .partial_cmp(&b.distance)
467 .unwrap_or(std::cmp::Ordering::Equal)
468 });
469 all_results.truncate(config.top_k);
470 all_results
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy for the `embedding` column used by every test table:
    /// cosine metric, F16 precision, no PQ, raw vectors not kept for reranking.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Writes and commits `default.table` under `dir` with `rows` rows. Row `i`
    /// has `id == i` and a one-hot embedding whose hot axis is `i % dim`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let id_values: Vec<i32> = (0..rows as i32).collect();
        let id_field = Field::new("id", DataType::Int32, false);
        let schema = Arc::new(Schema::new(vec![id_field]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(id_values))]).unwrap();

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(rows);
        for row in 0..rows {
            let mut one_hot = vec![0.0f32; dim];
            one_hot[row % dim] = 1.0;
            embeddings.push(one_hot);
        }

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Search configuration shared by the tests; only `top_k` and the rerank
    /// factor vary between them.
    fn test_config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Opens fresh store and catalog handles on `dir` and searches the demo
    /// table with a unit query along axis 0.
    async fn search_demo_table(
        dir: &TempDir,
        dim: usize,
        config: SearchConfig,
    ) -> Vec<SearchResult> {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        search(
            &table,
            &query,
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    /// Reranking over-fetches candidates but must still return exactly `top_k` rows.
    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_demo_table(&dir, dim, test_config(3, Some(2))).await;

        assert_eq!(results.len(), 3);
    }

    /// With reranking, the best hit for a query equal to row 0's embedding is
    /// row 0 at (near-)zero distance.
    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_demo_table(&dir, dim, test_config(1, Some(4))).await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    /// Plain and reranked searches must agree on the nearest row.
    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let plain = search_demo_table(&dir, dim, test_config(2, None)).await;
        let reranked = search_demo_table(&dir, dim, test_config(2, Some(2))).await;

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}