// ailake_query/scanner.rs

1use std::sync::Arc;
2
3use rayon::prelude::*;
4
5use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
6use ailake_core::{AilakeResult, RowId, VectorMetric};
7use ailake_file::AilakeFileReader;
8use ailake_index::AnyIndex;
9use ailake_store::Store;
10use ailake_vec::exact_distance;
11use bytes::Bytes;
12
13use crate::pruner::VectorPruner;
14
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and adjust with the `with_*` builders, or
/// construct it directly as a struct literal.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results returned after merging the per-file/per-shard candidates.
    pub top_k: usize,
    /// Search breadth passed as the third argument to the index's `search` call.
    /// Ignored by exact flat scans.
    pub ef_search: usize,
    /// Maximum distance from query to file centroid edge for a file to be searched.
    /// Files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
    /// Set to `f32::INFINITY` to disable pruning (scan all files).
    pub pruning_threshold: f32,
    /// When `Some(factor)`, fetch `top_k * factor` candidates from the index and
    /// rerank them using exact F32 distances before truncating to `top_k`.
    /// This corrects the approximation error of compressed (e.g. PQ or F16) index distances.
    /// `None` (default) disables reranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// `top_k = 10`, `ef_search = 50`, pruning disabled, reranking disabled.
    fn default() -> Self {
        Self {
            top_k: 10,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}

impl SearchConfig {
    /// Enable geometric pruning with the given centroid-edge distance `threshold`.
    #[must_use]
    pub fn with_pruning(mut self, threshold: f32) -> Self {
        self.pruning_threshold = threshold;
        self
    }

    /// Enable exact reranking of `top_k * factor` index candidates.
    ///
    /// A `factor` of 0 is clamped to 1. A zero factor would request zero candidates
    /// and make every search return an empty result.
    #[must_use]
    pub fn with_reranking(mut self, factor: usize) -> Self {
        self.rerank_factor = Some(factor.max(1));
        self
    }
}
52
53#[derive(Debug)]
54pub struct SearchResult {
55    pub row_id: RowId,
56    pub distance: f32,
57    pub file_path: String,
58}
59
60/// Search across all files in the latest snapshot, with geometric pruning.
61///
62/// Flow:
63/// 1. Load file list from catalog (includes centroid metadata)
64/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
65/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
66/// 4. Global merge of all per-file top-k lists, return global top-k
67pub async fn search(
68    table: &TableIdent,
69    query: &[f32],
70    config: SearchConfig,
71    vector_column: &str,
72    dim: u32,
73    catalog: Arc<dyn CatalogProvider>,
74    store: Arc<dyn Store>,
75) -> AilakeResult<Vec<SearchResult>> {
76    // Get file metadata (includes centroid info) without reading any data files
77    let all_files = catalog.list_files(table, None).await?;
78
79    // Determine vector metric from table metadata for correct distance computation
80    let table_meta = catalog.load_table(table).await?;
81    let metric = parse_metric(
82        table_meta
83            .properties
84            .get("ailake.vector-metric")
85            .map(String::as_str)
86            .unwrap_or("cosine"),
87    );
88
89    // Geometric pruning: skip files whose centroid is too far from the query
90    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
91
92    let candidate_k = match config.rerank_factor {
93        Some(factor) => config.top_k * factor,
94        None => config.top_k,
95    };
96
97    let mut all_results: Vec<SearchResult> = Vec::new();
98
99    for file_entry in &surviving_files {
100        let file_bytes: Bytes = store.get(&file_entry.path).await?;
101        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
102
103        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
104            // HNSW not yet built — flat scan over raw vectors.
105            let (_, raw_vectors) = reader.read_parquet()?;
106            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
107                all_results.push(SearchResult {
108                    row_id,
109                    distance,
110                    file_path: file_entry.path.clone(),
111                });
112            }
113            continue;
114        }
115
116        let index = reader.load_any_index_for_column(vector_column)?;
117        let local_results = index.search(query, candidate_k, config.ef_search);
118
119        if config.rerank_factor.is_some() {
120            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
121            let (_, raw_vectors) = reader.read_parquet()?;
122            for (row_id, _approx_dist) in local_results {
123                let idx = row_id.as_u64() as usize;
124                let exact_dist = raw_vectors
125                    .get(idx)
126                    .map(|v| exact_distance(metric, query, v))
127                    .unwrap_or(f32::INFINITY);
128                all_results.push(SearchResult {
129                    row_id,
130                    distance: exact_dist,
131                    file_path: file_entry.path.clone(),
132                });
133            }
134        } else {
135            for (row_id, distance) in local_results {
136                all_results.push(SearchResult {
137                    row_id,
138                    distance,
139                    file_path: file_entry.path.clone(),
140                });
141            }
142        }
143    }
144
145    // Global merge: sort all candidates by distance, keep top-k
146    all_results.sort_by(|a, b| {
147        a.distance
148            .partial_cmp(&b.distance)
149            .unwrap_or(std::cmp::Ordering::Equal)
150    });
151    all_results.truncate(config.top_k);
152    Ok(all_results)
153}
154
155/// Brute-force top-k search over raw vectors. Used for Indexing shards.
156fn flat_search(
157    raw: &[Vec<f32>],
158    query: &[f32],
159    top_k: usize,
160    metric: VectorMetric,
161) -> Vec<(RowId, f32)> {
162    let mut results: Vec<(RowId, f32)> = raw
163        .iter()
164        .enumerate()
165        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
166        .collect();
167    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
168    results.truncate(top_k);
169    results
170}
171
172fn parse_metric(s: &str) -> VectorMetric {
173    match s {
174        "euclidean" => VectorMetric::Euclidean,
175        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
176        _ => VectorMetric::Cosine,
177    }
178}
179
180/// Pre-loaded search session: all HNSW indexes loaded into memory once.
181///
182/// Useful for benchmarks and servers that issue many queries against the same
183/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
184pub struct SearchSession {
185    shards: Vec<LoadedShard>,
186    metric: VectorMetric,
187}
188
189struct LoadedShard {
190    entry: DataFileEntry,
191    /// None when the shard is still being indexed (IndexStatus::Indexing).
192    index: Option<AnyIndex>,
193    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
194    /// present for Ready shards when `load_raw = true` (reranking).
195    raw_vectors: Option<Vec<Vec<f32>>>,
196}
197
198impl SearchSession {
199    /// Load all indexes for the latest snapshot into memory.
200    ///
201    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
202    /// `Some`); it reads the full parquet columns so exact distances are
203    /// available without extra I/O during `search_query`.
204    pub async fn load(
205        table: &TableIdent,
206        vector_column: &str,
207        dim: u32,
208        catalog: Arc<dyn CatalogProvider>,
209        store: Arc<dyn Store>,
210        load_raw: bool,
211    ) -> AilakeResult<Self> {
212        let all_files = catalog.list_files(table, None).await?;
213        let table_meta = catalog.load_table(table).await?;
214        let metric = parse_metric(
215            table_meta
216                .properties
217                .get("ailake.vector-metric")
218                .map(String::as_str)
219                .unwrap_or("cosine"),
220        );
221
222        let mut shards = Vec::with_capacity(all_files.len());
223        for entry in all_files {
224            let file_bytes: Bytes = store.get(&entry.path).await?;
225            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
226
227            if entry.index_status == IndexStatus::Indexing {
228                // HNSW not yet built — load raw vectors for flat scan.
229                let (_, raw_vecs) = reader.read_parquet()?;
230                shards.push(LoadedShard {
231                    entry,
232                    index: None,
233                    raw_vectors: Some(raw_vecs),
234                });
235            } else if reader.is_ailake_file() {
236                let mut index = reader.load_any_index_for_column(vector_column)?;
237                let raw_vectors = if load_raw {
238                    index.quantize_to_f16();
239                    let (_, vecs) = reader.read_parquet()?;
240                    Some(vecs)
241                } else {
242                    None
243                };
244                shards.push(LoadedShard {
245                    entry,
246                    index: Some(index),
247                    raw_vectors,
248                });
249            }
250        }
251
252        Ok(Self { shards, metric })
253    }
254
255    /// Number of loaded shards.
256    pub fn shard_count(&self) -> usize {
257        self.shards.len()
258    }
259
260    /// Search multiple queries in one call.
261    ///
262    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
263    /// matmul when a CUDA device is available, falling back to CPU flat scan.
264    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
265    /// traversal is inherently sequential and has no GPU batch path.
266    ///
267    /// Returns one `Vec<SearchResult>` per input query, in the same order.
268    pub fn search_batch(
269        &self,
270        queries: &[Vec<f32>],
271        config: &SearchConfig,
272    ) -> Vec<Vec<SearchResult>> {
273        if queries.is_empty() {
274            return vec![];
275        }
276
277        let n_queries = queries.len();
278        let candidate_k = match config.rerank_factor {
279            Some(factor) => config.top_k * factor,
280            None => config.top_k,
281        };
282        let use_nvidia = ailake_index::hardware::detect_cuda();
283        let use_amd = ailake_index::hardware::detect_rocm();
284
285        // Accumulate per-query results across all shards.
286        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
287
288        for shard in &self.shards {
289            if let Some(raw) = &shard.raw_vectors {
290                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
291                if !raw.is_empty() {
292                    let dim = raw[0].len();
293                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
294                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
295                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
296
297                    let gpu_batch = if use_nvidia {
298                        ailake_index::gpu::try_nvidia_search_batch(
299                            &q_refs,
300                            &row_ids,
301                            &flat,
302                            dim,
303                            self.metric,
304                            candidate_k,
305                        )
306                    } else if use_amd {
307                        ailake_index::gpu::try_rocm_search_batch(
308                            &q_refs,
309                            &row_ids,
310                            &flat,
311                            dim,
312                            self.metric,
313                            candidate_k,
314                        )
315                    } else {
316                        None
317                    };
318
319                    if let Some(batch) = gpu_batch {
320                        for (qi, results) in batch.into_iter().enumerate() {
321                            for (row_id, distance) in results {
322                                all_results[qi].push(SearchResult {
323                                    row_id,
324                                    distance,
325                                    file_path: shard.entry.path.clone(),
326                                });
327                            }
328                        }
329                        continue;
330                    }
331                }
332
333                // CPU fallback for flat scan.
334                for (qi, query) in queries.iter().enumerate() {
335                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
336                        all_results[qi].push(SearchResult {
337                            row_id,
338                            distance,
339                            file_path: shard.entry.path.clone(),
340                        });
341                    }
342                }
343            } else if let Some(index) = &shard.index {
344                // Indexed shard — rayon parallel-map over queries.
345                let shard_results: Vec<Vec<SearchResult>> = queries
346                    .par_iter()
347                    .map(|query| {
348                        index
349                            .search(query, candidate_k, config.ef_search)
350                            .into_iter()
351                            .map(|(row_id, distance)| SearchResult {
352                                row_id,
353                                distance,
354                                file_path: shard.entry.path.clone(),
355                            })
356                            .collect()
357                    })
358                    .collect();
359
360                for (qi, results) in shard_results.into_iter().enumerate() {
361                    all_results[qi].extend(results);
362                }
363            }
364        }
365
366        // Sort + truncate per query.
367        for results in &mut all_results {
368            results.sort_by(|a, b| {
369                a.distance
370                    .partial_cmp(&b.distance)
371                    .unwrap_or(std::cmp::Ordering::Equal)
372            });
373            results.truncate(config.top_k);
374        }
375
376        all_results
377    }
378
379    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
380    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
381        let candidate_k = match config.rerank_factor {
382            Some(factor) => config.top_k * factor,
383            None => config.top_k,
384        };
385
386        let mut all_results: Vec<SearchResult> = self
387            .shards
388            .par_iter()
389            .flat_map(|shard| {
390                // Geometric pruning per shard.
391                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
392                    let dist = match self.metric {
393                        VectorMetric::Cosine => {
394                            ailake_vec::cosine_distance(query, &centroid.values)
395                        }
396                        VectorMetric::Euclidean => {
397                            ailake_vec::euclidean_distance(query, &centroid.values)
398                        }
399                        VectorMetric::DotProduct => {
400                            -ailake_vec::dot_product(query, &centroid.values)
401                        }
402                    };
403                    if dist - centroid.radius > config.pruning_threshold {
404                        return vec![];
405                    }
406                }
407
408                if let Some(index) = &shard.index {
409                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
410                    let local_results = index.search(query, candidate_k, config.ef_search);
411                    if config.rerank_factor.is_some() {
412                        if let Some(raw) = &shard.raw_vectors {
413                            local_results
414                                .into_iter()
415                                .map(|(row_id, _approx_dist)| {
416                                    let idx = row_id.as_u64() as usize;
417                                    let exact_dist = raw
418                                        .get(idx)
419                                        .map(|v| exact_distance(self.metric, query, v))
420                                        .unwrap_or(f32::INFINITY);
421                                    SearchResult {
422                                        row_id,
423                                        distance: exact_dist,
424                                        file_path: shard.entry.path.clone(),
425                                    }
426                                })
427                                .collect()
428                        } else {
429                            local_results
430                                .into_iter()
431                                .map(|(row_id, distance)| SearchResult {
432                                    row_id,
433                                    distance,
434                                    file_path: shard.entry.path.clone(),
435                                })
436                                .collect()
437                        }
438                    } else {
439                        local_results
440                            .into_iter()
441                            .map(|(row_id, distance)| SearchResult {
442                                row_id,
443                                distance,
444                                file_path: shard.entry.path.clone(),
445                            })
446                            .collect()
447                    }
448                } else if let Some(raw) = &shard.raw_vectors {
449                    // Indexing shard: exact flat scan.
450                    flat_search(raw, query, candidate_k, self.metric)
451                        .into_iter()
452                        .map(|(row_id, distance)| SearchResult {
453                            row_id,
454                            distance,
455                            file_path: shard.entry.path.clone(),
456                        })
457                        .collect()
458                } else {
459                    vec![]
460                }
461            })
462            .collect();
463
464        all_results.sort_by(|a, b| {
465            a.distance
466                .partial_cmp(&b.distance)
467                .unwrap_or(std::cmp::Ordering::Equal)
468        });
469        all_results.truncate(config.top_k);
470        all_results
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Name of the vector column written by `write_demo_table`.
    const VECTOR_COLUMN: &str = "embedding";

    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: VECTOR_COLUMN.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: false,
        }
    }

    /// Unit vector of length `dim` with `1.0` at position `i`.
    fn basis(dim: usize, i: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[i] = 1.0;
        v
    }

    /// Fresh store and catalog handles over `dir`, plus the demo table's identifier.
    fn open_table(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>, TableIdent) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (store, catalog, TableIdent::new("default", "table"))
    }

    /// Search config with pruning disabled and the given `top_k` / rerank factor.
    fn config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Run the free `search` function on the demo table in `dir`, with fresh handles.
    async fn run_search(
        dir: &TempDir,
        dim: usize,
        query: &[f32],
        cfg: SearchConfig,
    ) -> Vec<SearchResult> {
        let (store, catalog, table) = open_table(dir);
        search(&table, query, cfg, VECTOR_COLUMN, dim as u32, catalog, store)
            .await
            .unwrap()
    }

    /// Write and commit a table of `rows` rows. Row `i` embeds `basis(dim, i % dim)`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        let embeddings: Vec<Vec<f32>> = (0..rows).map(|i| basis(dim, i % dim)).collect();

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = run_search(&dir, dim, &basis(dim, 0), config(3, Some(2))).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // Row 0 stores the query vector itself, so its cosine distance is ~0.
        let results = run_search(&dir, dim, &basis(dim, 0), config(1, Some(4))).await;

        assert_eq!(results.len(), 1);
        // F16 storage may introduce a small rounding error.
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        let query = basis(dim, 0);

        // Each run opens its own store/catalog handles.
        let plain = run_search(&dir, dim, &query, config(2, None)).await;
        let reranked = run_search(&dir, dim, &query, config(2, Some(2))).await;

        assert_eq!(plain.len(), 2);
        assert_eq!(reranked.len(), 2);
        // Both should return the same top-1 result: row 0, which equals the query.
        assert_eq!(plain[0].row_id, RowId::new(0));
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}