1use std::sync::Arc;
3
4use rayon::prelude::*;
5
6use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
7use ailake_core::{AilakeResult, RowId, VectorMetric};
8use ailake_file::AilakeFileReader;
9use ailake_index::AnyIndex;
10use ailake_store::Store;
11use ailake_vec::exact_distance;
12use bytes::Bytes;
13
14use crate::pruner::VectorPruner;
15
/// Tunable parameters for a single vector search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results returned per query; merged results are
    /// truncated to this length after sorting by distance.
    pub top_k: usize,
    /// Passed through unchanged as the third argument of the index's `search`
    /// call (presumably an HNSW-style search breadth; confirm against
    /// `ailake_index`).
    pub ef_search: usize,
    /// Threshold handed to the file pruner. In the session path a shard is
    /// skipped when `centroid_distance - centroid_radius` exceeds this value.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, `top_k * factor` candidates are requested per file
    /// and, where raw vectors are available, re-scored with exact distances.
    /// `None` requests exactly `top_k` candidates and keeps index distances.
    pub rerank_factor: Option<usize>,
}
30
31impl Default for SearchConfig {
32 fn default() -> Self {
33 Self {
34 top_k: 10,
35 ef_search: 50,
36 pruning_threshold: f32::INFINITY,
37 rerank_factor: None,
38 }
39 }
40}
41
42impl SearchConfig {
43 pub fn with_pruning(mut self, threshold: f32) -> Self {
44 self.pruning_threshold = threshold;
45 self
46 }
47
48 pub fn with_reranking(mut self, factor: usize) -> Self {
49 self.rerank_factor = Some(factor);
50 self
51 }
52}
53
54#[derive(Debug)]
55pub struct SearchResult {
56 pub row_id: RowId,
57 pub distance: f32,
58 pub file_path: String,
59}
60
61pub async fn search(
69 table: &TableIdent,
70 query: &[f32],
71 config: SearchConfig,
72 vector_column: &str,
73 dim: u32,
74 catalog: Arc<dyn CatalogProvider>,
75 store: Arc<dyn Store>,
76) -> AilakeResult<Vec<SearchResult>> {
77 let all_files = catalog.list_files(table, None).await?;
79
80 let table_meta = catalog.load_table(table).await?;
82 let metric = parse_metric(
83 table_meta
84 .properties
85 .get("ailake.vector-metric")
86 .map(String::as_str)
87 .unwrap_or("cosine"),
88 );
89
90 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
92
93 let candidate_k = match config.rerank_factor {
94 Some(factor) => config.top_k * factor,
95 None => config.top_k,
96 };
97
98 let mut all_results: Vec<SearchResult> = Vec::new();
99
100 for file_entry in &surviving_files {
101 let file_bytes: Bytes = store.get(&file_entry.path).await?;
102 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
103
104 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
105 let (_, raw_vectors) = reader.read_parquet()?;
107 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
108 all_results.push(SearchResult {
109 row_id,
110 distance,
111 file_path: file_entry.path.clone(),
112 });
113 }
114 continue;
115 }
116
117 let index = reader.load_any_index_for_column(vector_column)?;
118 let local_results = index.search(query, candidate_k, config.ef_search);
119
120 if config.rerank_factor.is_some() {
121 let (_, raw_vectors) = reader.read_parquet()?;
123 for (row_id, _approx_dist) in local_results {
124 let idx = row_id.as_u64() as usize;
125 let exact_dist = raw_vectors
126 .get(idx)
127 .map(|v| exact_distance(metric, query, v))
128 .unwrap_or(f32::INFINITY);
129 all_results.push(SearchResult {
130 row_id,
131 distance: exact_dist,
132 file_path: file_entry.path.clone(),
133 });
134 }
135 } else {
136 for (row_id, distance) in local_results {
137 all_results.push(SearchResult {
138 row_id,
139 distance,
140 file_path: file_entry.path.clone(),
141 });
142 }
143 }
144 }
145
146 all_results.sort_by(|a, b| {
148 a.distance
149 .partial_cmp(&b.distance)
150 .unwrap_or(std::cmp::Ordering::Equal)
151 });
152 all_results.truncate(config.top_k);
153 Ok(all_results)
154}
155
156fn flat_search(
158 raw: &[Vec<f32>],
159 query: &[f32],
160 top_k: usize,
161 metric: VectorMetric,
162) -> Vec<(RowId, f32)> {
163 let mut results: Vec<(RowId, f32)> = raw
164 .iter()
165 .enumerate()
166 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
167 .collect();
168 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
169 results.truncate(top_k);
170 results
171}
172
173fn parse_metric(s: &str) -> VectorMetric {
174 match s {
175 "euclidean" => VectorMetric::Euclidean,
176 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
177 _ => VectorMetric::Cosine,
178 }
179}
180
181pub struct SearchSession {
186 shards: Vec<LoadedShard>,
187 metric: VectorMetric,
188}
189
190struct LoadedShard {
191 entry: DataFileEntry,
192 index: Option<AnyIndex>,
194 raw_vectors: Option<Vec<Vec<f32>>>,
197}
198
199impl SearchSession {
200 pub async fn load(
206 table: &TableIdent,
207 vector_column: &str,
208 dim: u32,
209 catalog: Arc<dyn CatalogProvider>,
210 store: Arc<dyn Store>,
211 load_raw: bool,
212 ) -> AilakeResult<Self> {
213 let all_files = catalog.list_files(table, None).await?;
214 let table_meta = catalog.load_table(table).await?;
215 let metric = parse_metric(
216 table_meta
217 .properties
218 .get("ailake.vector-metric")
219 .map(String::as_str)
220 .unwrap_or("cosine"),
221 );
222
223 let mut shards = Vec::with_capacity(all_files.len());
224 for entry in all_files {
225 let file_bytes: Bytes = store.get(&entry.path).await?;
226 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
227
228 if entry.index_status == IndexStatus::Indexing {
229 let (_, raw_vecs) = reader.read_parquet()?;
231 shards.push(LoadedShard {
232 entry,
233 index: None,
234 raw_vectors: Some(raw_vecs),
235 });
236 } else if reader.is_ailake_file() {
237 let mut index = reader.load_any_index_for_column(vector_column)?;
238 let raw_vectors = if load_raw {
239 index.quantize_to_f16();
240 let (_, vecs) = reader.read_parquet()?;
241 Some(vecs)
242 } else {
243 None
244 };
245 shards.push(LoadedShard {
246 entry,
247 index: Some(index),
248 raw_vectors,
249 });
250 }
251 }
252
253 Ok(Self { shards, metric })
254 }
255
256 pub fn shard_count(&self) -> usize {
258 self.shards.len()
259 }
260
261 pub fn search_batch(
270 &self,
271 queries: &[Vec<f32>],
272 config: &SearchConfig,
273 ) -> Vec<Vec<SearchResult>> {
274 if queries.is_empty() {
275 return vec![];
276 }
277
278 let n_queries = queries.len();
279 let candidate_k = match config.rerank_factor {
280 Some(factor) => config.top_k * factor,
281 None => config.top_k,
282 };
283 let use_nvidia = ailake_index::hardware::detect_cuda();
284 let use_amd = ailake_index::hardware::detect_rocm();
285
286 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
288
289 for shard in &self.shards {
290 if let Some(raw) = &shard.raw_vectors {
291 if !raw.is_empty() {
293 let dim = raw[0].len();
294 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
295 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
296 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
297
298 let gpu_batch = if use_nvidia {
299 ailake_index::gpu::try_nvidia_search_batch(
300 &q_refs,
301 &row_ids,
302 &flat,
303 dim,
304 self.metric,
305 candidate_k,
306 )
307 } else if use_amd {
308 ailake_index::gpu::try_rocm_search_batch(
309 &q_refs,
310 &row_ids,
311 &flat,
312 dim,
313 self.metric,
314 candidate_k,
315 )
316 } else {
317 None
318 };
319
320 if let Some(batch) = gpu_batch {
321 for (qi, results) in batch.into_iter().enumerate() {
322 for (row_id, distance) in results {
323 all_results[qi].push(SearchResult {
324 row_id,
325 distance,
326 file_path: shard.entry.path.clone(),
327 });
328 }
329 }
330 continue;
331 }
332 }
333
334 for (qi, query) in queries.iter().enumerate() {
336 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
337 all_results[qi].push(SearchResult {
338 row_id,
339 distance,
340 file_path: shard.entry.path.clone(),
341 });
342 }
343 }
344 } else if let Some(index) = &shard.index {
345 let shard_results: Vec<Vec<SearchResult>> = queries
347 .par_iter()
348 .map(|query| {
349 index
350 .search(query, candidate_k, config.ef_search)
351 .into_iter()
352 .map(|(row_id, distance)| SearchResult {
353 row_id,
354 distance,
355 file_path: shard.entry.path.clone(),
356 })
357 .collect()
358 })
359 .collect();
360
361 for (qi, results) in shard_results.into_iter().enumerate() {
362 all_results[qi].extend(results);
363 }
364 }
365 }
366
367 for results in &mut all_results {
369 results.sort_by(|a, b| {
370 a.distance
371 .partial_cmp(&b.distance)
372 .unwrap_or(std::cmp::Ordering::Equal)
373 });
374 results.truncate(config.top_k);
375 }
376
377 all_results
378 }
379
380 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
382 let candidate_k = match config.rerank_factor {
383 Some(factor) => config.top_k * factor,
384 None => config.top_k,
385 };
386
387 let mut all_results: Vec<SearchResult> = self
388 .shards
389 .par_iter()
390 .flat_map(|shard| {
391 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
393 let dist = match self.metric {
394 VectorMetric::Cosine => {
395 ailake_vec::cosine_distance(query, ¢roid.values)
396 }
397 VectorMetric::Euclidean => {
398 ailake_vec::euclidean_distance(query, ¢roid.values)
399 }
400 VectorMetric::DotProduct => {
401 -ailake_vec::dot_product(query, ¢roid.values)
402 }
403 };
404 if dist - centroid.radius > config.pruning_threshold {
405 return vec![];
406 }
407 }
408
409 if let Some(index) = &shard.index {
410 let local_results = index.search(query, candidate_k, config.ef_search);
412 if config.rerank_factor.is_some() {
413 if let Some(raw) = &shard.raw_vectors {
414 local_results
415 .into_iter()
416 .map(|(row_id, _approx_dist)| {
417 let idx = row_id.as_u64() as usize;
418 let exact_dist = raw
419 .get(idx)
420 .map(|v| exact_distance(self.metric, query, v))
421 .unwrap_or(f32::INFINITY);
422 SearchResult {
423 row_id,
424 distance: exact_dist,
425 file_path: shard.entry.path.clone(),
426 }
427 })
428 .collect()
429 } else {
430 local_results
431 .into_iter()
432 .map(|(row_id, distance)| SearchResult {
433 row_id,
434 distance,
435 file_path: shard.entry.path.clone(),
436 })
437 .collect()
438 }
439 } else {
440 local_results
441 .into_iter()
442 .map(|(row_id, distance)| SearchResult {
443 row_id,
444 distance,
445 file_path: shard.entry.path.clone(),
446 })
447 .collect()
448 }
449 } else if let Some(raw) = &shard.raw_vectors {
450 flat_search(raw, query, candidate_k, self.metric)
452 .into_iter()
453 .map(|(row_id, distance)| SearchResult {
454 row_id,
455 distance,
456 file_path: shard.entry.path.clone(),
457 })
458 .collect()
459 } else {
460 vec![]
461 }
462 })
463 .collect();
464
465 all_results.sort_by(|a, b| {
466 a.distance
467 .partial_cmp(&b.distance)
468 .unwrap_or(std::cmp::Ordering::Equal)
469 });
470 all_results.truncate(config.top_k);
471 all_results
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy used by the demo table: cosine metric, f16 precision,
    /// no product quantization, raw vectors not kept for reranking.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Identifier of the table every test writes to and reads from.
    fn demo_table() -> TableIdent {
        TableIdent::new("default", "table")
    }

    /// Writes `rows` one-hot embeddings of dimension `dim` (row `i` has its
    /// `1.0` at position `i % dim`) plus an `id` column, and commits them.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(rows);
        for i in 0..rows {
            let mut one_hot = vec![0.0f32; dim];
            one_hot[i % dim] = 1.0;
            embeddings.push(one_hot);
        }

        let mut writer = crate::TableWriter::create_or_open(
            catalog,
            store,
            make_policy(dim as u32),
            demo_table(),
        )
        .await
        .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Opens a fresh catalog/store pair over the test directory.
    fn open_handles(dir: &TempDir) -> (Arc<dyn CatalogProvider>, Arc<dyn Store>) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (catalog, store)
    }

    /// Search configuration with the fixed `ef_search` and disabled pruning
    /// shared by all tests.
    fn config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Queries the demo table with the unit vector along the first axis,
    /// i.e. the exact embedding of row 0.
    async fn search_first_axis(
        dir: &TempDir,
        dim: usize,
        config: SearchConfig,
    ) -> Vec<SearchResult> {
        let (catalog, store) = open_handles(dir);
        let table = demo_table();

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        search(
            &table,
            &query,
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_first_axis(&dir, dim, config(3, Some(2))).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_first_axis(&dir, dim, config(1, Some(4))).await;

        assert_eq!(results.len(), 1);
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        // Each search gets its own catalog/store handles over the same data.
        let plain = search_first_axis(&dir, dim, config(2, None)).await;
        let reranked = search_first_axis(&dir, dim, config(2, Some(2))).await;

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}