Skip to main content

ailake_query/
scanner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3
4use rayon::prelude::*;
5
6use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
7use ailake_core::{AilakeResult, RowId, VectorMetric};
8use ailake_file::AilakeFileReader;
9use ailake_index::AnyIndex;
10use ailake_store::Store;
11use ailake_vec::exact_distance;
12use bytes::Bytes;
13
14use crate::pruner::VectorPruner;
15
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and adjust with the builder-style
/// helpers ([`SearchConfig::with_pruning`], [`SearchConfig::with_reranking`]),
/// or set the public fields directly.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results returned after all per-file candidates are merged.
    pub top_k: usize,
    /// Search-breadth parameter forwarded unchanged to each index's `search`
    /// call (presumably HNSW's `ef`; TODO confirm for IVF-PQ indexes).
    pub ef_search: usize,
    /// Maximum distance from query to file centroid edge for a file to be searched.
    /// Files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
    /// Use `f32::INFINITY` to turn pruning off and scan every file.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, fetch `top_k * factor` candidates from the index and
    /// rescore them with exact F32 distances before truncating to `top_k`.
    /// This corrects the approximation error of compressed index distances.
    /// `None` (the default) disables reranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// `top_k = 10`, `ef_search = 50`, no pruning, no reranking.
    fn default() -> Self {
        SearchConfig {
            top_k: 10,
            ef_search: 50,
            rerank_factor: None,
            pruning_threshold: f32::INFINITY,
        }
    }
}

impl SearchConfig {
    /// Returns this configuration with `pruning_threshold` replaced by `threshold`.
    pub fn with_pruning(self, threshold: f32) -> Self {
        Self {
            pruning_threshold: threshold,
            ..self
        }
    }

    /// Returns this configuration with exact reranking enabled, over-fetching
    /// `top_k * factor` candidates per file.
    pub fn with_reranking(self, factor: usize) -> Self {
        Self {
            rerank_factor: Some(factor),
            ..self
        }
    }
}
53
54#[derive(Debug)]
55pub struct SearchResult {
56    pub row_id: RowId,
57    pub distance: f32,
58    pub file_path: String,
59}
60
61/// Search across all files in the latest snapshot, with geometric pruning.
62///
63/// Flow:
64/// 1. Load file list from catalog (includes centroid metadata)
65/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
66/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
67/// 4. Global merge of all per-file top-k lists, return global top-k
68pub async fn search(
69    table: &TableIdent,
70    query: &[f32],
71    config: SearchConfig,
72    vector_column: &str,
73    dim: u32,
74    catalog: Arc<dyn CatalogProvider>,
75    store: Arc<dyn Store>,
76) -> AilakeResult<Vec<SearchResult>> {
77    // Get file metadata (includes centroid info) without reading any data files
78    let all_files = catalog.list_files(table, None).await?;
79
80    // Determine vector metric from table metadata for correct distance computation
81    let table_meta = catalog.load_table(table).await?;
82    let metric = parse_metric(
83        table_meta
84            .properties
85            .get("ailake.vector-metric")
86            .map(String::as_str)
87            .unwrap_or("cosine"),
88    );
89
90    // Geometric pruning: skip files whose centroid is too far from the query
91    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
92
93    let candidate_k = match config.rerank_factor {
94        Some(factor) => config.top_k * factor,
95        None => config.top_k,
96    };
97
98    let mut all_results: Vec<SearchResult> = Vec::new();
99
100    for file_entry in &surviving_files {
101        let file_bytes: Bytes = store.get(&file_entry.path).await?;
102        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
103
104        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
105            // HNSW not yet built — flat scan over raw vectors.
106            let (_, raw_vectors) = reader.read_parquet()?;
107            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
108                all_results.push(SearchResult {
109                    row_id,
110                    distance,
111                    file_path: file_entry.path.clone(),
112                });
113            }
114            continue;
115        }
116
117        let index = reader.load_any_index_for_column(vector_column)?;
118        let local_results = index.search(query, candidate_k, config.ef_search);
119
120        if config.rerank_factor.is_some() {
121            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
122            let (_, raw_vectors) = reader.read_parquet()?;
123            for (row_id, _approx_dist) in local_results {
124                let idx = row_id.as_u64() as usize;
125                let exact_dist = raw_vectors
126                    .get(idx)
127                    .map(|v| exact_distance(metric, query, v))
128                    .unwrap_or(f32::INFINITY);
129                all_results.push(SearchResult {
130                    row_id,
131                    distance: exact_dist,
132                    file_path: file_entry.path.clone(),
133                });
134            }
135        } else {
136            for (row_id, distance) in local_results {
137                all_results.push(SearchResult {
138                    row_id,
139                    distance,
140                    file_path: file_entry.path.clone(),
141                });
142            }
143        }
144    }
145
146    // Global merge: sort all candidates by distance, keep top-k
147    all_results.sort_by(|a, b| {
148        a.distance
149            .partial_cmp(&b.distance)
150            .unwrap_or(std::cmp::Ordering::Equal)
151    });
152    all_results.truncate(config.top_k);
153    Ok(all_results)
154}
155
156/// Brute-force top-k search over raw vectors. Used for Indexing shards.
157fn flat_search(
158    raw: &[Vec<f32>],
159    query: &[f32],
160    top_k: usize,
161    metric: VectorMetric,
162) -> Vec<(RowId, f32)> {
163    let mut results: Vec<(RowId, f32)> = raw
164        .iter()
165        .enumerate()
166        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
167        .collect();
168    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
169    results.truncate(top_k);
170    results
171}
172
173fn parse_metric(s: &str) -> VectorMetric {
174    match s {
175        "euclidean" => VectorMetric::Euclidean,
176        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
177        _ => VectorMetric::Cosine,
178    }
179}
180
181/// Pre-loaded search session: all HNSW indexes loaded into memory once.
182///
183/// Useful for benchmarks and servers that issue many queries against the same
184/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
185pub struct SearchSession {
186    shards: Vec<LoadedShard>,
187    metric: VectorMetric,
188}
189
190struct LoadedShard {
191    entry: DataFileEntry,
192    /// None when the shard is still being indexed (IndexStatus::Indexing).
193    index: Option<AnyIndex>,
194    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
195    /// present for Ready shards when `load_raw = true` (reranking).
196    raw_vectors: Option<Vec<Vec<f32>>>,
197}
198
199impl SearchSession {
200    /// Load all indexes for the latest snapshot into memory.
201    ///
202    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
203    /// `Some`); it reads the full parquet columns so exact distances are
204    /// available without extra I/O during `search_query`.
205    pub async fn load(
206        table: &TableIdent,
207        vector_column: &str,
208        dim: u32,
209        catalog: Arc<dyn CatalogProvider>,
210        store: Arc<dyn Store>,
211        load_raw: bool,
212    ) -> AilakeResult<Self> {
213        let all_files = catalog.list_files(table, None).await?;
214        let table_meta = catalog.load_table(table).await?;
215        let metric = parse_metric(
216            table_meta
217                .properties
218                .get("ailake.vector-metric")
219                .map(String::as_str)
220                .unwrap_or("cosine"),
221        );
222
223        let mut shards = Vec::with_capacity(all_files.len());
224        for entry in all_files {
225            let file_bytes: Bytes = store.get(&entry.path).await?;
226            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
227
228            if entry.index_status == IndexStatus::Indexing {
229                // HNSW not yet built — load raw vectors for flat scan.
230                let (_, raw_vecs) = reader.read_parquet()?;
231                shards.push(LoadedShard {
232                    entry,
233                    index: None,
234                    raw_vectors: Some(raw_vecs),
235                });
236            } else if reader.is_ailake_file() {
237                let mut index = reader.load_any_index_for_column(vector_column)?;
238                let raw_vectors = if load_raw {
239                    index.quantize_to_f16();
240                    let (_, vecs) = reader.read_parquet()?;
241                    Some(vecs)
242                } else {
243                    None
244                };
245                shards.push(LoadedShard {
246                    entry,
247                    index: Some(index),
248                    raw_vectors,
249                });
250            }
251        }
252
253        Ok(Self { shards, metric })
254    }
255
256    /// Number of loaded shards.
257    pub fn shard_count(&self) -> usize {
258        self.shards.len()
259    }
260
261    /// Search multiple queries in one call.
262    ///
263    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
264    /// matmul when a CUDA device is available, falling back to CPU flat scan.
265    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
266    /// traversal is inherently sequential and has no GPU batch path.
267    ///
268    /// Returns one `Vec<SearchResult>` per input query, in the same order.
269    pub fn search_batch(
270        &self,
271        queries: &[Vec<f32>],
272        config: &SearchConfig,
273    ) -> Vec<Vec<SearchResult>> {
274        if queries.is_empty() {
275            return vec![];
276        }
277
278        let n_queries = queries.len();
279        let candidate_k = match config.rerank_factor {
280            Some(factor) => config.top_k * factor,
281            None => config.top_k,
282        };
283        let use_nvidia = ailake_index::hardware::detect_cuda();
284        let use_amd = ailake_index::hardware::detect_rocm();
285
286        // Accumulate per-query results across all shards.
287        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
288
289        for shard in &self.shards {
290            if let Some(raw) = &shard.raw_vectors {
291                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
292                if !raw.is_empty() {
293                    let dim = raw[0].len();
294                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
295                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
296                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
297
298                    let gpu_batch = if use_nvidia {
299                        ailake_index::gpu::try_nvidia_search_batch(
300                            &q_refs,
301                            &row_ids,
302                            &flat,
303                            dim,
304                            self.metric,
305                            candidate_k,
306                        )
307                    } else if use_amd {
308                        ailake_index::gpu::try_rocm_search_batch(
309                            &q_refs,
310                            &row_ids,
311                            &flat,
312                            dim,
313                            self.metric,
314                            candidate_k,
315                        )
316                    } else {
317                        None
318                    };
319
320                    if let Some(batch) = gpu_batch {
321                        for (qi, results) in batch.into_iter().enumerate() {
322                            for (row_id, distance) in results {
323                                all_results[qi].push(SearchResult {
324                                    row_id,
325                                    distance,
326                                    file_path: shard.entry.path.clone(),
327                                });
328                            }
329                        }
330                        continue;
331                    }
332                }
333
334                // CPU fallback for flat scan.
335                for (qi, query) in queries.iter().enumerate() {
336                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
337                        all_results[qi].push(SearchResult {
338                            row_id,
339                            distance,
340                            file_path: shard.entry.path.clone(),
341                        });
342                    }
343                }
344            } else if let Some(index) = &shard.index {
345                // Indexed shard — rayon parallel-map over queries.
346                let shard_results: Vec<Vec<SearchResult>> = queries
347                    .par_iter()
348                    .map(|query| {
349                        index
350                            .search(query, candidate_k, config.ef_search)
351                            .into_iter()
352                            .map(|(row_id, distance)| SearchResult {
353                                row_id,
354                                distance,
355                                file_path: shard.entry.path.clone(),
356                            })
357                            .collect()
358                    })
359                    .collect();
360
361                for (qi, results) in shard_results.into_iter().enumerate() {
362                    all_results[qi].extend(results);
363                }
364            }
365        }
366
367        // Sort + truncate per query.
368        for results in &mut all_results {
369            results.sort_by(|a, b| {
370                a.distance
371                    .partial_cmp(&b.distance)
372                    .unwrap_or(std::cmp::Ordering::Equal)
373            });
374            results.truncate(config.top_k);
375        }
376
377        all_results
378    }
379
380    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
381    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
382        let candidate_k = match config.rerank_factor {
383            Some(factor) => config.top_k * factor,
384            None => config.top_k,
385        };
386
387        let mut all_results: Vec<SearchResult> = self
388            .shards
389            .par_iter()
390            .flat_map(|shard| {
391                // Geometric pruning per shard.
392                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
393                    let dist = match self.metric {
394                        VectorMetric::Cosine => {
395                            ailake_vec::cosine_distance(query, &centroid.values)
396                        }
397                        VectorMetric::Euclidean => {
398                            ailake_vec::euclidean_distance(query, &centroid.values)
399                        }
400                        VectorMetric::DotProduct => {
401                            -ailake_vec::dot_product(query, &centroid.values)
402                        }
403                    };
404                    if dist - centroid.radius > config.pruning_threshold {
405                        return vec![];
406                    }
407                }
408
409                if let Some(index) = &shard.index {
410                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
411                    let local_results = index.search(query, candidate_k, config.ef_search);
412                    if config.rerank_factor.is_some() {
413                        if let Some(raw) = &shard.raw_vectors {
414                            local_results
415                                .into_iter()
416                                .map(|(row_id, _approx_dist)| {
417                                    let idx = row_id.as_u64() as usize;
418                                    let exact_dist = raw
419                                        .get(idx)
420                                        .map(|v| exact_distance(self.metric, query, v))
421                                        .unwrap_or(f32::INFINITY);
422                                    SearchResult {
423                                        row_id,
424                                        distance: exact_dist,
425                                        file_path: shard.entry.path.clone(),
426                                    }
427                                })
428                                .collect()
429                        } else {
430                            local_results
431                                .into_iter()
432                                .map(|(row_id, distance)| SearchResult {
433                                    row_id,
434                                    distance,
435                                    file_path: shard.entry.path.clone(),
436                                })
437                                .collect()
438                        }
439                    } else {
440                        local_results
441                            .into_iter()
442                            .map(|(row_id, distance)| SearchResult {
443                                row_id,
444                                distance,
445                                file_path: shard.entry.path.clone(),
446                            })
447                            .collect()
448                    }
449                } else if let Some(raw) = &shard.raw_vectors {
450                    // Indexing shard: exact flat scan.
451                    flat_search(raw, query, candidate_k, self.metric)
452                        .into_iter()
453                        .map(|(row_id, distance)| SearchResult {
454                            row_id,
455                            distance,
456                            file_path: shard.entry.path.clone(),
457                        })
458                        .collect()
459                } else {
460                    vec![]
461                }
462            })
463            .collect();
464
465        all_results.sort_by(|a, b| {
466            a.distance
467                .partial_cmp(&b.distance)
468                .unwrap_or(std::cmp::Ordering::Equal)
469        });
470        all_results.truncate(config.top_k);
471        all_results
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Unit basis vector of length `dim` with 1.0 at position `axis % dim`.
    fn basis(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis % dim] = 1.0;
        v
    }

    /// Writes and commits `default.table` with `rows` rows; row `i` is
    /// embedded as `basis(dim, i)`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        let embeddings: Vec<Vec<f32>> = (0..rows).map(|i| basis(dim, i)).collect();

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Fresh store and catalog handles over the demo warehouse, plus the demo table id.
    fn open_demo(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>, TableIdent) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (store, catalog, TableIdent::new("default", "table"))
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_demo(&dir);

        let query = basis(dim, 0);
        let config = SearchConfig {
            top_k: 3,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: Some(2),
        };

        let results = search(&table, &query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_demo(&dir);

        // Row 0 is [1,0,0,...]; its cosine distance to the same query is 0.
        let query = basis(dim, 0);
        let config = SearchConfig {
            top_k: 1,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: Some(4),
        };

        let results = search(&table, &query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed).
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        // `search` consumes its handles, so open one set per call.
        let (store_a, cat_a, table) = open_demo(&dir);
        let (store_b, cat_b, _) = open_demo(&dir);

        let query = basis(dim, 0);
        let cfg_plain = SearchConfig {
            top_k: 2,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        };
        let cfg_rerank = SearchConfig {
            top_k: 2,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: Some(2),
        };

        let plain = search(&table, &query, cfg_plain, "embedding", dim as u32, cat_a, store_a)
            .await
            .unwrap();
        let reranked = search(&table, &query, cfg_rerank, "embedding", dim as u32, cat_b, store_b)
            .await
            .unwrap();

        assert!(!plain.is_empty(), "plain search returned no results");
        assert!(!reranked.is_empty(), "reranked search returned no results");
        // Both should return the same top-1 result (row 0, distance ~0).
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }

    #[tokio::test]
    async fn session_batch_matches_single_queries() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_demo(&dir);

        let session = SearchSession::load(&table, "embedding", dim as u32, catalog, store, false)
            .await
            .unwrap();
        assert!(session.shard_count() > 0);

        let config = SearchConfig {
            top_k: 1,
            ..SearchConfig::default()
        };
        assert!(session.search_batch(&[], &config).is_empty());

        let axes = [0usize, 5];
        let queries: Vec<Vec<f32>> = axes.iter().map(|&a| basis(dim, a)).collect();
        let batch = session.search_batch(&queries, &config);
        assert_eq!(batch.len(), queries.len());

        for ((query, batched), &axis) in queries.iter().zip(&batch).zip(&axes) {
            let single = session.search_query(query, &config);
            assert_eq!(batched.len(), 1);
            assert_eq!(single.len(), 1);
            assert_eq!(batched[0].row_id, single[0].row_id);
            assert_eq!(batched[0].row_id, RowId::new(axis as u64));
        }
    }
}