Skip to main content

ailake_query/
scanner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::RecordBatch;
14use bytes::Bytes;
15
16use crate::pruner::VectorPruner;
17
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and adjust with the `with_*` builders, or set the
/// public fields directly.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results returned after the global merge across files/shards.
    pub top_k: usize,
    /// Search-effort parameter forwarded to the ANN index search call (HNSW `ef`;
    /// how other index kinds interpret it is index-specific).
    pub ef_search: usize,
    /// Maximum distance from query to file centroid edge for a file to be searched.
    /// Files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
    /// Set to `f32::INFINITY` to disable pruning (scan all files).
    pub pruning_threshold: f32,
    /// When `Some(factor)`, fetch `top_k * factor` candidates from the ANN index and
    /// rerank them using exact F32 distances before truncating to `top_k`.
    /// Corrects the approximation error introduced by compressed index distances.
    /// `None` (default) disables reranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// `top_k = 10`, `ef_search = 50`, pruning disabled, no reranking.
    fn default() -> Self {
        Self {
            top_k: 10,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}

impl SearchConfig {
    /// Returns the config with `top_k` replaced.
    #[must_use]
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = top_k;
        self
    }

    /// Returns the config with `ef_search` replaced.
    #[must_use]
    pub fn with_ef_search(mut self, ef_search: usize) -> Self {
        self.ef_search = ef_search;
        self
    }

    /// Returns the config with geometric pruning enabled at `threshold`.
    #[must_use]
    pub fn with_pruning(mut self, threshold: f32) -> Self {
        self.pruning_threshold = threshold;
        self
    }

    /// Returns the config with exact-distance reranking over `top_k * factor` candidates.
    #[must_use]
    pub fn with_reranking(mut self, factor: usize) -> Self {
        self.rerank_factor = Some(factor);
        self
    }
}
55
56#[derive(Debug)]
57pub struct SearchResult {
58    pub row_id: RowId,
59    pub distance: f32,
60    pub file_path: String,
61}
62
63/// Search across all files in the latest snapshot, with geometric pruning.
64///
65/// Flow:
66/// 1. Load file list from catalog (includes centroid metadata)
67/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
68/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
69/// 4. Global merge of all per-file top-k lists, return global top-k
70pub async fn search(
71    table: &TableIdent,
72    query: &[f32],
73    config: SearchConfig,
74    vector_column: &str,
75    dim: u32,
76    catalog: Arc<dyn CatalogProvider>,
77    store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79    // Get file metadata (includes centroid info) without reading any data files
80    let all_files = catalog.list_files(table, None).await?;
81
82    // Determine vector metric from table metadata for correct distance computation
83    let table_meta = catalog.load_table(table).await?;
84    let metric = parse_metric(
85        table_meta
86            .properties
87            .get("ailake.vector-metric")
88            .map(String::as_str)
89            .unwrap_or("cosine"),
90    );
91
92    // Geometric pruning: skip files whose centroid is too far from the query
93    let total_files = all_files.len();
94    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
95    debug!(
96        "ailake: geometric pruning — {}/{} files survive (threshold={})",
97        surviving_files.len(),
98        total_files,
99        config.pruning_threshold
100    );
101
102    let candidate_k = match config.rerank_factor {
103        Some(factor) => config.top_k * factor,
104        None => config.top_k,
105    };
106
107    let mut all_results: Vec<SearchResult> = Vec::new();
108
109    for file_entry in &surviving_files {
110        let file_bytes: Bytes = store.get(&file_entry.path).await?;
111        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
112
113        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
114            // HNSW not yet built — flat scan over raw vectors.
115            debug!(
116                "ailake: flat scan fallback for {} (index_status={:?})",
117                file_entry.path, file_entry.index_status
118            );
119            let (_, raw_vectors) = reader.read_parquet()?;
120            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
121                all_results.push(SearchResult {
122                    row_id,
123                    distance,
124                    file_path: file_entry.path.clone(),
125                });
126            }
127            continue;
128        }
129
130        let index = reader.load_any_index_for_column(vector_column)?;
131        let local_results = index.search(query, candidate_k, config.ef_search);
132
133        if config.rerank_factor.is_some() {
134            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
135            let (_, raw_vectors) = reader.read_parquet()?;
136            for (row_id, _approx_dist) in local_results {
137                let idx = row_id.as_u64() as usize;
138                let exact_dist = match raw_vectors.get(idx) {
139                    Some(v) => exact_distance(metric, query, v),
140                    None => {
141                        error!(
142                            "ailake: invariant violated — row_id {} out of bounds \
143                             (raw_vectors.len={}, file={}); \
144                             Parquet row count and HNSW node count are out of sync; \
145                             file may be corrupt — run compaction to rebuild",
146                            idx,
147                            raw_vectors.len(),
148                            file_entry.path
149                        );
150                        f32::INFINITY
151                    }
152                };
153                all_results.push(SearchResult {
154                    row_id,
155                    distance: exact_dist,
156                    file_path: file_entry.path.clone(),
157                });
158            }
159        } else {
160            for (row_id, distance) in local_results {
161                all_results.push(SearchResult {
162                    row_id,
163                    distance,
164                    file_path: file_entry.path.clone(),
165                });
166            }
167        }
168    }
169
170    // Global merge: sort all candidates by distance, keep top-k
171    all_results.sort_by(|a, b| {
172        a.distance
173            .partial_cmp(&b.distance)
174            .unwrap_or(std::cmp::Ordering::Equal)
175    });
176    all_results.truncate(config.top_k);
177    Ok(all_results)
178}
179
180/// Brute-force top-k search over raw vectors. Used for Indexing shards.
181fn flat_search(
182    raw: &[Vec<f32>],
183    query: &[f32],
184    top_k: usize,
185    metric: VectorMetric,
186) -> Vec<(RowId, f32)> {
187    let mut results: Vec<(RowId, f32)> = raw
188        .iter()
189        .enumerate()
190        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
191        .collect();
192    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
193    results.truncate(top_k);
194    results
195}
196
197fn parse_metric(s: &str) -> VectorMetric {
198    match s {
199        "euclidean" => VectorMetric::Euclidean,
200        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
201        _ => VectorMetric::Cosine,
202    }
203}
204
205/// Pre-loaded search session: all HNSW indexes loaded into memory once.
206///
207/// Useful for benchmarks and servers that issue many queries against the same
208/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
209pub struct SearchSession {
210    shards: Vec<LoadedShard>,
211    metric: VectorMetric,
212}
213
214struct LoadedShard {
215    entry: DataFileEntry,
216    /// None when the shard is still being indexed (IndexStatus::Indexing).
217    index: Option<AnyIndex>,
218    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
219    /// present for Ready shards when `load_raw = true` (reranking).
220    raw_vectors: Option<Vec<Vec<f32>>>,
221}
222
223impl SearchSession {
224    /// Load all indexes for the latest snapshot into memory.
225    ///
226    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
227    /// `Some`); it reads the full parquet columns so exact distances are
228    /// available without extra I/O during `search_query`.
229    pub async fn load(
230        table: &TableIdent,
231        vector_column: &str,
232        dim: u32,
233        catalog: Arc<dyn CatalogProvider>,
234        store: Arc<dyn Store>,
235        load_raw: bool,
236    ) -> AilakeResult<Self> {
237        let all_files = catalog.list_files(table, None).await?;
238        let table_meta = catalog.load_table(table).await?;
239        let metric = parse_metric(
240            table_meta
241                .properties
242                .get("ailake.vector-metric")
243                .map(String::as_str)
244                .unwrap_or("cosine"),
245        );
246
247        let mut shards = Vec::with_capacity(all_files.len());
248        for entry in all_files {
249            let file_bytes: Bytes = store.get(&entry.path).await?;
250            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
251
252            if entry.index_status == IndexStatus::Indexing {
253                // HNSW not yet built — load raw vectors for flat scan.
254                let (_, raw_vecs) = reader.read_parquet()?;
255                shards.push(LoadedShard {
256                    entry,
257                    index: None,
258                    raw_vectors: Some(raw_vecs),
259                });
260            } else if reader.is_ailake_file() {
261                let mut index = reader.load_any_index_for_column(vector_column)?;
262                let raw_vectors = if load_raw {
263                    index.quantize_to_f16();
264                    let (_, vecs) = reader.read_parquet()?;
265                    Some(vecs)
266                } else {
267                    None
268                };
269                shards.push(LoadedShard {
270                    entry,
271                    index: Some(index),
272                    raw_vectors,
273                });
274            }
275        }
276
277        Ok(Self { shards, metric })
278    }
279
280    /// Number of loaded shards.
281    pub fn shard_count(&self) -> usize {
282        self.shards.len()
283    }
284
285    /// Search multiple queries in one call.
286    ///
287    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
288    /// matmul when a CUDA device is available, falling back to CPU flat scan.
289    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
290    /// traversal is inherently sequential and has no GPU batch path.
291    ///
292    /// Returns one `Vec<SearchResult>` per input query, in the same order.
293    pub fn search_batch(
294        &self,
295        queries: &[Vec<f32>],
296        config: &SearchConfig,
297    ) -> Vec<Vec<SearchResult>> {
298        if queries.is_empty() {
299            return vec![];
300        }
301
302        let n_queries = queries.len();
303        let candidate_k = match config.rerank_factor {
304            Some(factor) => config.top_k * factor,
305            None => config.top_k,
306        };
307        let use_nvidia = ailake_index::hardware::detect_cuda();
308        let use_amd = ailake_index::hardware::detect_rocm();
309
310        // Accumulate per-query results across all shards.
311        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
312
313        for shard in &self.shards {
314            if let Some(raw) = &shard.raw_vectors {
315                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
316                if !raw.is_empty() {
317                    let dim = raw[0].len();
318                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
319                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
320                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
321
322                    let gpu_batch = if use_nvidia {
323                        ailake_index::gpu::try_nvidia_search_batch(
324                            &q_refs,
325                            &row_ids,
326                            &flat,
327                            dim,
328                            self.metric,
329                            candidate_k,
330                        )
331                    } else if use_amd {
332                        ailake_index::gpu::try_rocm_search_batch(
333                            &q_refs,
334                            &row_ids,
335                            &flat,
336                            dim,
337                            self.metric,
338                            candidate_k,
339                        )
340                    } else {
341                        None
342                    };
343
344                    if let Some(batch) = gpu_batch {
345                        for (qi, results) in batch.into_iter().enumerate() {
346                            for (row_id, distance) in results {
347                                all_results[qi].push(SearchResult {
348                                    row_id,
349                                    distance,
350                                    file_path: shard.entry.path.clone(),
351                                });
352                            }
353                        }
354                        continue;
355                    }
356                }
357
358                // CPU fallback for flat scan.
359                for (qi, query) in queries.iter().enumerate() {
360                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
361                        all_results[qi].push(SearchResult {
362                            row_id,
363                            distance,
364                            file_path: shard.entry.path.clone(),
365                        });
366                    }
367                }
368            } else if let Some(index) = &shard.index {
369                // Indexed shard — rayon parallel-map over queries.
370                let shard_results: Vec<Vec<SearchResult>> = queries
371                    .par_iter()
372                    .map(|query| {
373                        index
374                            .search(query, candidate_k, config.ef_search)
375                            .into_iter()
376                            .map(|(row_id, distance)| SearchResult {
377                                row_id,
378                                distance,
379                                file_path: shard.entry.path.clone(),
380                            })
381                            .collect()
382                    })
383                    .collect();
384
385                for (qi, results) in shard_results.into_iter().enumerate() {
386                    all_results[qi].extend(results);
387                }
388            }
389        }
390
391        // Sort + truncate per query.
392        for results in &mut all_results {
393            results.sort_by(|a, b| {
394                a.distance
395                    .partial_cmp(&b.distance)
396                    .unwrap_or(std::cmp::Ordering::Equal)
397            });
398            results.truncate(config.top_k);
399        }
400
401        all_results
402    }
403
404    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
405    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
406        let candidate_k = match config.rerank_factor {
407            Some(factor) => config.top_k * factor,
408            None => config.top_k,
409        };
410
411        let mut all_results: Vec<SearchResult> = self
412            .shards
413            .par_iter()
414            .flat_map(|shard| {
415                // Geometric pruning per shard.
416                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
417                    let dist = match self.metric {
418                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
419                            ailake_vec::cosine_distance(query, &centroid.values)
420                        }
421                        VectorMetric::Euclidean => {
422                            ailake_vec::euclidean_distance(query, &centroid.values)
423                        }
424                        VectorMetric::DotProduct => {
425                            -ailake_vec::dot_product(query, &centroid.values)
426                        }
427                    };
428                    if dist - centroid.radius > config.pruning_threshold {
429                        return vec![];
430                    }
431                }
432
433                if let Some(index) = &shard.index {
434                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
435                    let local_results = index.search(query, candidate_k, config.ef_search);
436                    if config.rerank_factor.is_some() {
437                        if let Some(raw) = &shard.raw_vectors {
438                            local_results
439                                .into_iter()
440                                .map(|(row_id, _approx_dist)| {
441                                    let idx = row_id.as_u64() as usize;
442                                    let exact_dist = raw
443                                        .get(idx)
444                                        .map(|v| exact_distance(self.metric, query, v))
445                                        .unwrap_or(f32::INFINITY);
446                                    SearchResult {
447                                        row_id,
448                                        distance: exact_dist,
449                                        file_path: shard.entry.path.clone(),
450                                    }
451                                })
452                                .collect()
453                        } else {
454                            local_results
455                                .into_iter()
456                                .map(|(row_id, distance)| SearchResult {
457                                    row_id,
458                                    distance,
459                                    file_path: shard.entry.path.clone(),
460                                })
461                                .collect()
462                        }
463                    } else {
464                        local_results
465                            .into_iter()
466                            .map(|(row_id, distance)| SearchResult {
467                                row_id,
468                                distance,
469                                file_path: shard.entry.path.clone(),
470                            })
471                            .collect()
472                    }
473                } else if let Some(raw) = &shard.raw_vectors {
474                    // Indexing shard: exact flat scan.
475                    flat_search(raw, query, candidate_k, self.metric)
476                        .into_iter()
477                        .map(|(row_id, distance)| SearchResult {
478                            row_id,
479                            distance,
480                            file_path: shard.entry.path.clone(),
481                        })
482                        .collect()
483                } else {
484                    vec![]
485                }
486            })
487            .collect();
488
489        all_results.sort_by(|a, b| {
490            a.distance
491                .partial_cmp(&b.distance)
492                .unwrap_or(std::cmp::Ordering::Equal)
493        });
494        all_results.truncate(config.top_k);
495        all_results
496    }
497}
498
499/// Fetch full row data for a slice of search results.
500///
501/// Groups results by Parquet file, reads each file once, extracts the matching rows
502/// via `arrow_select::take`, then concatenates everything back in original top-k order
503/// with a `_distance: Float32` column appended.
504///
505/// Use this immediately after `search()` to retrieve the actual text / metadata
506/// columns (e.g. `chunk_text`, `document_title`) alongside the distance scores.
507pub async fn fetch_rows(
508    results: &[SearchResult],
509    store: Arc<dyn Store>,
510    vector_column: &str,
511    dim: u32,
512) -> AilakeResult<RecordBatch> {
513    use std::collections::HashMap;
514
515    use arrow_array::{ArrayRef, Float32Array, UInt32Array};
516    use arrow_schema::{DataType, Field, Schema};
517    use arrow_select::{concat::concat_batches, take::take};
518
519    if results.is_empty() {
520        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
521    }
522
523    // Group by file path; preserve original position for re-sorting.
524    let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
525    for (i, r) in results.iter().enumerate() {
526        by_file
527            .entry(r.file_path.as_str())
528            .or_default()
529            .push((r.row_id.as_u64(), r.distance, i));
530    }
531
532    use arrow_array::FixedSizeListArray;
533
534    // (original_index, distance, single-row RecordBatch, decoded F32 vector)
535    let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
536
537    for (file_path, rows) in &by_file {
538        let bytes = store.get(file_path).await?;
539        let reader = AilakeFileReader::new(bytes, vector_column, dim);
540        let (batch, vectors) = reader.read_parquet()?;
541
542        for &(row_id, distance, pos) in rows {
543            let idx = row_id as usize;
544            if idx >= batch.num_rows() {
545                tracing::warn!(
546                    "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
547                    idx,
548                    batch.num_rows(),
549                    file_path
550                );
551                continue;
552            }
553
554            let indices = UInt32Array::from(vec![idx as u32]);
555            let row_cols: Vec<ArrayRef> = batch
556                .columns()
557                .iter()
558                .map(|col| {
559                    take(col.as_ref(), &indices, None)
560                        .map_err(|e| AilakeError::Arrow(e.to_string()))
561                })
562                .collect::<AilakeResult<Vec<_>>>()?;
563
564            let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
565                .map_err(|e| AilakeError::Arrow(e.to_string()))?;
566
567            // Capture decoded F32 vector for this row (empty vec if not available).
568            let vec = vectors
569                .get(idx)
570                .cloned()
571                .unwrap_or_else(|| vec![0.0f32; dim as usize]);
572
573            collected.push((pos, distance, row_batch, vec));
574        }
575    }
576
577    if collected.is_empty() {
578        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
579    }
580
581    // Restore original top-k order from the search results slice.
582    collected.sort_by_key(|(pos, _, _, _)| *pos);
583
584    let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
585    let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
586    let base_schema = collected[0].2.schema();
587
588    let combined =
589        concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
590
591    // Build FixedSizeList<Float32> column with decoded vectors (F32, not raw F16 bytes).
592    let flat_vecs: Vec<f32> = collected
593        .iter()
594        .flat_map(|(_, _, _, v)| v.iter().copied())
595        .collect();
596    let item_field = Arc::new(Field::new("item", DataType::Float32, false));
597    let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
598    let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
599    let vec_field = Arc::new(Field::new(
600        vector_column,
601        DataType::FixedSizeList(item_field, dim as i32),
602        false,
603    ));
604
605    // Schema: tabular cols, then decoded vector col, then _distance.
606    let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
607    fields.push(vec_field);
608    fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
609    let new_schema = Arc::new(Schema::new(fields));
610
611    let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
612    columns.push(Arc::new(vec_col));
613    columns.push(Arc::new(Float32Array::from(distances)));
614
615    RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Demo storage policy: cosine metric, F16 storage, raw vectors kept for reranking.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
        }
    }

    /// Unit basis vector of length `dim` with 1.0 at `axis` and 0 elsewhere.
    fn basis(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis] = 1.0;
        v
    }

    /// Writes and commits `default.table` under `dir`. Row `i` gets `id = i` and the
    /// embedding `basis(dim, i % dim)`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        let embeddings: Vec<Vec<f32>> = (0..rows).map(|i| basis(dim, i % dim)).collect();

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Config with pruning disabled and `ef_search = 50`.
    fn no_pruning(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Runs `search` over the demo table in `dir`, using fresh store/catalog handles.
    async fn run_search(
        dir: &TempDir,
        dim: usize,
        query: &[f32],
        config: SearchConfig,
    ) -> Vec<SearchResult> {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        search(
            &table,
            query,
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = run_search(&dir, dim, &basis(dim, 0), no_pruning(3, Some(2))).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // Row 0 holds basis(8, 0) — its cosine distance to the same query is 0.
        let results = run_search(&dir, dim, &basis(dim, 0), no_pruning(1, Some(4))).await;

        assert_eq!(results.len(), 1);
        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed).
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let query = basis(dim, 0);
        let plain = run_search(&dir, dim, &query, no_pruning(2, None)).await;
        let reranked = run_search(&dir, dim, &query, no_pruning(2, Some(2))).await;

        assert!(!plain.is_empty(), "plain search returned no results");
        assert!(!reranked.is_empty(), "reranked search returned no results");
        // Both should return the same top-1 result (row 0, distance ~0).
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}