// ailake_query/scanner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeResult, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use bytes::Bytes;
14
15use crate::pruner::VectorPruner;
16
/// Tunable parameters for a single vector search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results kept after the global merge.
    pub top_k: usize,
    /// Search-breadth parameter forwarded unchanged to the index's `search` call.
    pub ef_search: usize,
    /// Geometric pruning cutoff, measured from the query to the edge of a file's
    /// centroid ball: a file is skipped when
    /// `distance(query, centroid) - radius > pruning_threshold`.
    /// `f32::INFINITY` turns pruning off, so every file is scanned.
    pub pruning_threshold: f32,
    /// Over-fetch multiplier for exact reranking.
    ///
    /// With `Some(factor)`, `top_k * factor` candidates are requested from the
    /// HNSW index and re-scored with exact F32 distances before the list is cut
    /// back to `top_k`; this corrects the approximation error of PQ-compressed
    /// HNSW distances. `None` (the default) means no reranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// Ten results, `ef_search` of 50, pruning disabled, reranking disabled.
    fn default() -> Self {
        SearchConfig {
            rerank_factor: None,
            pruning_threshold: f32::INFINITY,
            ef_search: 50,
            top_k: 10,
        }
    }
}

impl SearchConfig {
    /// Returns this configuration with `pruning_threshold` replaced by `threshold`.
    pub fn with_pruning(self, threshold: f32) -> Self {
        Self {
            pruning_threshold: threshold,
            ..self
        }
    }

    /// Returns this configuration with reranking enabled at the given `factor`.
    pub fn with_reranking(self, factor: usize) -> Self {
        Self {
            rerank_factor: Some(factor),
            ..self
        }
    }
}
54
55#[derive(Debug)]
56pub struct SearchResult {
57    pub row_id: RowId,
58    pub distance: f32,
59    pub file_path: String,
60}
61
62/// Search across all files in the latest snapshot, with geometric pruning.
63///
64/// Flow:
65/// 1. Load file list from catalog (includes centroid metadata)
66/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
67/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
68/// 4. Global merge of all per-file top-k lists, return global top-k
69pub async fn search(
70    table: &TableIdent,
71    query: &[f32],
72    config: SearchConfig,
73    vector_column: &str,
74    dim: u32,
75    catalog: Arc<dyn CatalogProvider>,
76    store: Arc<dyn Store>,
77) -> AilakeResult<Vec<SearchResult>> {
78    // Get file metadata (includes centroid info) without reading any data files
79    let all_files = catalog.list_files(table, None).await?;
80
81    // Determine vector metric from table metadata for correct distance computation
82    let table_meta = catalog.load_table(table).await?;
83    let metric = parse_metric(
84        table_meta
85            .properties
86            .get("ailake.vector-metric")
87            .map(String::as_str)
88            .unwrap_or("cosine"),
89    );
90
91    // Geometric pruning: skip files whose centroid is too far from the query
92    let total_files = all_files.len();
93    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
94    debug!(
95        "ailake: geometric pruning — {}/{} files survive (threshold={})",
96        surviving_files.len(),
97        total_files,
98        config.pruning_threshold
99    );
100
101    let candidate_k = match config.rerank_factor {
102        Some(factor) => config.top_k * factor,
103        None => config.top_k,
104    };
105
106    let mut all_results: Vec<SearchResult> = Vec::new();
107
108    for file_entry in &surviving_files {
109        let file_bytes: Bytes = store.get(&file_entry.path).await?;
110        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
111
112        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
113            // HNSW not yet built — flat scan over raw vectors.
114            debug!(
115                "ailake: flat scan fallback for {} (index_status={:?})",
116                file_entry.path, file_entry.index_status
117            );
118            let (_, raw_vectors) = reader.read_parquet()?;
119            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
120                all_results.push(SearchResult {
121                    row_id,
122                    distance,
123                    file_path: file_entry.path.clone(),
124                });
125            }
126            continue;
127        }
128
129        let index = reader.load_any_index_for_column(vector_column)?;
130        let local_results = index.search(query, candidate_k, config.ef_search);
131
132        if config.rerank_factor.is_some() {
133            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
134            let (_, raw_vectors) = reader.read_parquet()?;
135            for (row_id, _approx_dist) in local_results {
136                let idx = row_id.as_u64() as usize;
137                let exact_dist = match raw_vectors.get(idx) {
138                    Some(v) => exact_distance(metric, query, v),
139                    None => {
140                        error!(
141                            "ailake: invariant violated — row_id {} out of bounds \
142                             (raw_vectors.len={}, file={}); \
143                             Parquet row count and HNSW node count are out of sync; \
144                             file may be corrupt — run compaction to rebuild",
145                            idx,
146                            raw_vectors.len(),
147                            file_entry.path
148                        );
149                        f32::INFINITY
150                    }
151                };
152                all_results.push(SearchResult {
153                    row_id,
154                    distance: exact_dist,
155                    file_path: file_entry.path.clone(),
156                });
157            }
158        } else {
159            for (row_id, distance) in local_results {
160                all_results.push(SearchResult {
161                    row_id,
162                    distance,
163                    file_path: file_entry.path.clone(),
164                });
165            }
166        }
167    }
168
169    // Global merge: sort all candidates by distance, keep top-k
170    all_results.sort_by(|a, b| {
171        a.distance
172            .partial_cmp(&b.distance)
173            .unwrap_or(std::cmp::Ordering::Equal)
174    });
175    all_results.truncate(config.top_k);
176    Ok(all_results)
177}
178
179/// Brute-force top-k search over raw vectors. Used for Indexing shards.
180fn flat_search(
181    raw: &[Vec<f32>],
182    query: &[f32],
183    top_k: usize,
184    metric: VectorMetric,
185) -> Vec<(RowId, f32)> {
186    let mut results: Vec<(RowId, f32)> = raw
187        .iter()
188        .enumerate()
189        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
190        .collect();
191    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
192    results.truncate(top_k);
193    results
194}
195
196fn parse_metric(s: &str) -> VectorMetric {
197    match s {
198        "euclidean" => VectorMetric::Euclidean,
199        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
200        _ => VectorMetric::Cosine,
201    }
202}
203
204/// Pre-loaded search session: all HNSW indexes loaded into memory once.
205///
206/// Useful for benchmarks and servers that issue many queries against the same
207/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
208pub struct SearchSession {
209    shards: Vec<LoadedShard>,
210    metric: VectorMetric,
211}
212
213struct LoadedShard {
214    entry: DataFileEntry,
215    /// None when the shard is still being indexed (IndexStatus::Indexing).
216    index: Option<AnyIndex>,
217    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
218    /// present for Ready shards when `load_raw = true` (reranking).
219    raw_vectors: Option<Vec<Vec<f32>>>,
220}
221
222impl SearchSession {
223    /// Load all indexes for the latest snapshot into memory.
224    ///
225    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
226    /// `Some`); it reads the full parquet columns so exact distances are
227    /// available without extra I/O during `search_query`.
228    pub async fn load(
229        table: &TableIdent,
230        vector_column: &str,
231        dim: u32,
232        catalog: Arc<dyn CatalogProvider>,
233        store: Arc<dyn Store>,
234        load_raw: bool,
235    ) -> AilakeResult<Self> {
236        let all_files = catalog.list_files(table, None).await?;
237        let table_meta = catalog.load_table(table).await?;
238        let metric = parse_metric(
239            table_meta
240                .properties
241                .get("ailake.vector-metric")
242                .map(String::as_str)
243                .unwrap_or("cosine"),
244        );
245
246        let mut shards = Vec::with_capacity(all_files.len());
247        for entry in all_files {
248            let file_bytes: Bytes = store.get(&entry.path).await?;
249            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
250
251            if entry.index_status == IndexStatus::Indexing {
252                // HNSW not yet built — load raw vectors for flat scan.
253                let (_, raw_vecs) = reader.read_parquet()?;
254                shards.push(LoadedShard {
255                    entry,
256                    index: None,
257                    raw_vectors: Some(raw_vecs),
258                });
259            } else if reader.is_ailake_file() {
260                let mut index = reader.load_any_index_for_column(vector_column)?;
261                let raw_vectors = if load_raw {
262                    index.quantize_to_f16();
263                    let (_, vecs) = reader.read_parquet()?;
264                    Some(vecs)
265                } else {
266                    None
267                };
268                shards.push(LoadedShard {
269                    entry,
270                    index: Some(index),
271                    raw_vectors,
272                });
273            }
274        }
275
276        Ok(Self { shards, metric })
277    }
278
279    /// Number of loaded shards.
280    pub fn shard_count(&self) -> usize {
281        self.shards.len()
282    }
283
284    /// Search multiple queries in one call.
285    ///
286    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
287    /// matmul when a CUDA device is available, falling back to CPU flat scan.
288    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
289    /// traversal is inherently sequential and has no GPU batch path.
290    ///
291    /// Returns one `Vec<SearchResult>` per input query, in the same order.
292    pub fn search_batch(
293        &self,
294        queries: &[Vec<f32>],
295        config: &SearchConfig,
296    ) -> Vec<Vec<SearchResult>> {
297        if queries.is_empty() {
298            return vec![];
299        }
300
301        let n_queries = queries.len();
302        let candidate_k = match config.rerank_factor {
303            Some(factor) => config.top_k * factor,
304            None => config.top_k,
305        };
306        let use_nvidia = ailake_index::hardware::detect_cuda();
307        let use_amd = ailake_index::hardware::detect_rocm();
308
309        // Accumulate per-query results across all shards.
310        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
311
312        for shard in &self.shards {
313            if let Some(raw) = &shard.raw_vectors {
314                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
315                if !raw.is_empty() {
316                    let dim = raw[0].len();
317                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
318                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
319                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
320
321                    let gpu_batch = if use_nvidia {
322                        ailake_index::gpu::try_nvidia_search_batch(
323                            &q_refs,
324                            &row_ids,
325                            &flat,
326                            dim,
327                            self.metric,
328                            candidate_k,
329                        )
330                    } else if use_amd {
331                        ailake_index::gpu::try_rocm_search_batch(
332                            &q_refs,
333                            &row_ids,
334                            &flat,
335                            dim,
336                            self.metric,
337                            candidate_k,
338                        )
339                    } else {
340                        None
341                    };
342
343                    if let Some(batch) = gpu_batch {
344                        for (qi, results) in batch.into_iter().enumerate() {
345                            for (row_id, distance) in results {
346                                all_results[qi].push(SearchResult {
347                                    row_id,
348                                    distance,
349                                    file_path: shard.entry.path.clone(),
350                                });
351                            }
352                        }
353                        continue;
354                    }
355                }
356
357                // CPU fallback for flat scan.
358                for (qi, query) in queries.iter().enumerate() {
359                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
360                        all_results[qi].push(SearchResult {
361                            row_id,
362                            distance,
363                            file_path: shard.entry.path.clone(),
364                        });
365                    }
366                }
367            } else if let Some(index) = &shard.index {
368                // Indexed shard — rayon parallel-map over queries.
369                let shard_results: Vec<Vec<SearchResult>> = queries
370                    .par_iter()
371                    .map(|query| {
372                        index
373                            .search(query, candidate_k, config.ef_search)
374                            .into_iter()
375                            .map(|(row_id, distance)| SearchResult {
376                                row_id,
377                                distance,
378                                file_path: shard.entry.path.clone(),
379                            })
380                            .collect()
381                    })
382                    .collect();
383
384                for (qi, results) in shard_results.into_iter().enumerate() {
385                    all_results[qi].extend(results);
386                }
387            }
388        }
389
390        // Sort + truncate per query.
391        for results in &mut all_results {
392            results.sort_by(|a, b| {
393                a.distance
394                    .partial_cmp(&b.distance)
395                    .unwrap_or(std::cmp::Ordering::Equal)
396            });
397            results.truncate(config.top_k);
398        }
399
400        all_results
401    }
402
403    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
404    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
405        let candidate_k = match config.rerank_factor {
406            Some(factor) => config.top_k * factor,
407            None => config.top_k,
408        };
409
410        let mut all_results: Vec<SearchResult> = self
411            .shards
412            .par_iter()
413            .flat_map(|shard| {
414                // Geometric pruning per shard.
415                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
416                    let dist = match self.metric {
417                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
418                            ailake_vec::cosine_distance(query, &centroid.values)
419                        }
420                        VectorMetric::Euclidean => {
421                            ailake_vec::euclidean_distance(query, &centroid.values)
422                        }
423                        VectorMetric::DotProduct => {
424                            -ailake_vec::dot_product(query, &centroid.values)
425                        }
426                    };
427                    if dist - centroid.radius > config.pruning_threshold {
428                        return vec![];
429                    }
430                }
431
432                if let Some(index) = &shard.index {
433                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
434                    let local_results = index.search(query, candidate_k, config.ef_search);
435                    if config.rerank_factor.is_some() {
436                        if let Some(raw) = &shard.raw_vectors {
437                            local_results
438                                .into_iter()
439                                .map(|(row_id, _approx_dist)| {
440                                    let idx = row_id.as_u64() as usize;
441                                    let exact_dist = raw
442                                        .get(idx)
443                                        .map(|v| exact_distance(self.metric, query, v))
444                                        .unwrap_or(f32::INFINITY);
445                                    SearchResult {
446                                        row_id,
447                                        distance: exact_dist,
448                                        file_path: shard.entry.path.clone(),
449                                    }
450                                })
451                                .collect()
452                        } else {
453                            local_results
454                                .into_iter()
455                                .map(|(row_id, distance)| SearchResult {
456                                    row_id,
457                                    distance,
458                                    file_path: shard.entry.path.clone(),
459                                })
460                                .collect()
461                        }
462                    } else {
463                        local_results
464                            .into_iter()
465                            .map(|(row_id, distance)| SearchResult {
466                                row_id,
467                                distance,
468                                file_path: shard.entry.path.clone(),
469                            })
470                            .collect()
471                    }
472                } else if let Some(raw) = &shard.raw_vectors {
473                    // Indexing shard: exact flat scan.
474                    flat_search(raw, query, candidate_k, self.metric)
475                        .into_iter()
476                        .map(|(row_id, distance)| SearchResult {
477                            row_id,
478                            distance,
479                            file_path: shard.entry.path.clone(),
480                        })
481                        .collect()
482                } else {
483                    vec![]
484                }
485            })
486            .collect();
487
488        all_results.sort_by(|a, b| {
489            a.distance
490                .partial_cmp(&b.distance)
491                .unwrap_or(std::cmp::Ordering::Equal)
492        });
493        all_results.truncate(config.top_k);
494        all_results
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy for the `embedding` column used by every test:
    /// cosine metric, F16 precision, all optional features off.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            rabitq: None,
        }
    }

    /// Writes and commits a one-batch table of `rows` rows under `dir`.
    ///
    /// Row `i` carries id `i` and a unit basis vector: 1.0 at dimension
    /// `i % dim`, 0.0 everywhere else.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let id_field = Field::new("id", DataType::Int32, false);
        let schema = Arc::new(Schema::new(vec![id_field]));
        let id_column = Int32Array::from((0..rows as i32).collect::<Vec<i32>>());
        let batch = RecordBatch::try_new(schema, vec![Arc::new(id_column)]).unwrap();

        let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(rows);
        for i in 0..rows {
            let hot = i % dim;
            let mut basis = vec![0.0f32; dim];
            basis[hot] = 1.0;
            embeddings.push(basis);
        }

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Opens a fresh store + catalog pair over the table written under `dir`.
    fn open_table(dir: &TempDir) -> (Arc<dyn CatalogProvider>, Arc<dyn Store>) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (catalog, store)
    }

    /// Runs `search` for the first basis vector (`[1, 0, 0, ...]`) with pruning
    /// disabled and `ef_search` of 50.
    async fn search_first_axis(
        dir: &TempDir,
        dim: usize,
        top_k: usize,
        rerank_factor: Option<usize>,
    ) -> Vec<SearchResult> {
        let (catalog, store) = open_table(dir);
        let table = TableIdent::new("default", "table");

        let mut query = vec![0.0f32; dim];
        query[0] = 1.0;

        let config = SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        };

        search(
            &table,
            &query,
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = search_first_axis(&dir, dim, 3, Some(2)).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // Row 0 has [1,0,0,...] — cosine distance to same query is 0
        let results = search_first_axis(&dir, dim, 1, Some(4)).await;

        assert_eq!(results.len(), 1);
        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed)
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        // Each search gets its own store + catalog pair over the same directory.
        let plain = search_first_axis(&dir, dim, 2, None).await;
        let reranked = search_first_axis(&dir, dim, 2, Some(2)).await;

        // Both should return same top-1 result (row 0, distance ~0)
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}