// ailake_query/scanner.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
use std::sync::Arc;

use rayon::prelude::*;
use tracing::{debug, error};

use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
use ailake_file::AilakeFileReader;
use ailake_index::AnyIndex;
use ailake_store::Store;
use ailake_vec::exact_distance;
use arrow_array::RecordBatch;
use bytes::Bytes;

use crate::pruner::VectorPruner;

/// Tunable parameters for a vector search.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results kept after the global merge.
    pub top_k: usize,
    /// Search-effort parameter forwarded unchanged to the index's `search` call.
    pub ef_search: usize,
    /// Maximum distance from query to file centroid edge for a file to be searched.
    /// Files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
    /// Set to `f32::INFINITY` to disable pruning (scan all files).
    pub pruning_threshold: f32,
    /// When `Some(factor)`, fetch `top_k * factor` candidates from the index and
    /// rerank them using exact F32 distances before truncating to `top_k`.
    /// Corrects the approximation error introduced by PQ-compressed HNSW distances.
    /// `None` (default) disables reranking.
    pub rerank_factor: Option<usize>,
}

impl Default for SearchConfig {
    /// Top-10 search, `ef_search = 50`, pruning disabled, reranking disabled.
    fn default() -> Self {
        Self {
            top_k: 10,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}

impl SearchConfig {
    /// Returns the config with `pruning_threshold` replaced by `threshold`.
    #[must_use]
    pub fn with_pruning(mut self, threshold: f32) -> Self {
        self.pruning_threshold = threshold;
        self
    }

    /// Returns the config with reranking enabled using the given candidate `factor`.
    #[must_use]
    pub fn with_reranking(mut self, factor: usize) -> Self {
        self.rerank_factor = Some(factor);
        self
    }
}

56#[derive(Debug)]
57pub struct SearchResult {
58    pub row_id: RowId,
59    pub distance: f32,
60    pub file_path: String,
61}
62
63/// Search across all files in the latest snapshot, with geometric pruning.
64///
65/// Flow:
66/// 1. Load file list from catalog (includes centroid metadata)
67/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
68/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
69/// 4. Global merge of all per-file top-k lists, return global top-k
70pub async fn search(
71    table: &TableIdent,
72    query: &[f32],
73    config: SearchConfig,
74    vector_column: &str,
75    dim: u32,
76    catalog: Arc<dyn CatalogProvider>,
77    store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79    // Get file metadata (includes centroid info) without reading any data files
80    let all_files = catalog.list_files(table, None).await?;
81
82    // Determine vector metric from table metadata for correct distance computation
83    let table_meta = catalog.load_table(table).await?;
84
85    // Validate query dim against table dim — silent dim mismatch gives garbage results.
86    if let Some(table_dim_str) = table_meta.properties.get("ailake.vector-dim") {
87        if let Ok(table_dim) = table_dim_str.parse::<u32>() {
88            let query_dim = query.len() as u32;
89            if query_dim != table_dim {
90                let table_model = table_meta
91                    .properties
92                    .get(EmbeddingModelInfo::property_key())
93                    .cloned()
94                    .unwrap_or_else(|| format!("dim={}", table_dim));
95                return Err(AilakeError::ModelMismatch {
96                    table_model,
97                    table_dim,
98                    batch_model: format!("query dim={}", query_dim),
99                    batch_dim: query_dim,
100                });
101            }
102        }
103    }
104
105    let metric = parse_metric(
106        table_meta
107            .properties
108            .get("ailake.vector-metric")
109            .map(String::as_str)
110            .unwrap_or("cosine"),
111    );
112
113    // Geometric pruning: skip files whose centroid is too far from the query
114    let total_files = all_files.len();
115    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
116    debug!(
117        "ailake: geometric pruning — {}/{} files survive (threshold={})",
118        surviving_files.len(),
119        total_files,
120        config.pruning_threshold
121    );
122
123    let candidate_k = match config.rerank_factor {
124        Some(factor) => config.top_k * factor,
125        None => config.top_k,
126    };
127
128    let mut all_results: Vec<SearchResult> = Vec::new();
129
130    for file_entry in &surviving_files {
131        let file_bytes: Bytes = store.get(&file_entry.path).await?;
132        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
133
134        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
135            // HNSW not yet built — flat scan over raw vectors.
136            debug!(
137                "ailake: flat scan fallback for {} (index_status={:?})",
138                file_entry.path, file_entry.index_status
139            );
140            let (_, raw_vectors) = reader.read_parquet()?;
141            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
142                all_results.push(SearchResult {
143                    row_id,
144                    distance,
145                    file_path: file_entry.path.clone(),
146                });
147            }
148            continue;
149        }
150
151        let index = reader.load_any_index_for_column(vector_column)?;
152        let local_results = index.search(query, candidate_k, config.ef_search);
153
154        if config.rerank_factor.is_some() {
155            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
156            let (_, raw_vectors) = reader.read_parquet()?;
157            for (row_id, _approx_dist) in local_results {
158                let idx = row_id.as_u64() as usize;
159                let exact_dist = match raw_vectors.get(idx) {
160                    Some(v) => exact_distance(metric, query, v),
161                    None => {
162                        error!(
163                            "ailake: invariant violated — row_id {} out of bounds \
164                             (raw_vectors.len={}, file={}); \
165                             Parquet row count and HNSW node count are out of sync; \
166                             file may be corrupt — run compaction to rebuild",
167                            idx,
168                            raw_vectors.len(),
169                            file_entry.path
170                        );
171                        f32::INFINITY
172                    }
173                };
174                all_results.push(SearchResult {
175                    row_id,
176                    distance: exact_dist,
177                    file_path: file_entry.path.clone(),
178                });
179            }
180        } else {
181            for (row_id, distance) in local_results {
182                all_results.push(SearchResult {
183                    row_id,
184                    distance,
185                    file_path: file_entry.path.clone(),
186                });
187            }
188        }
189    }
190
191    // Global merge: sort all candidates by distance, keep top-k
192    all_results.sort_by(|a, b| {
193        a.distance
194            .partial_cmp(&b.distance)
195            .unwrap_or(std::cmp::Ordering::Equal)
196    });
197    all_results.truncate(config.top_k);
198    Ok(all_results)
199}
200
201/// Brute-force top-k search over raw vectors. Used for Indexing shards.
202fn flat_search(
203    raw: &[Vec<f32>],
204    query: &[f32],
205    top_k: usize,
206    metric: VectorMetric,
207) -> Vec<(RowId, f32)> {
208    let mut results: Vec<(RowId, f32)> = raw
209        .iter()
210        .enumerate()
211        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
212        .collect();
213    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
214    results.truncate(top_k);
215    results
216}
217
218fn parse_metric(s: &str) -> VectorMetric {
219    match s {
220        "euclidean" => VectorMetric::Euclidean,
221        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
222        _ => VectorMetric::Cosine,
223    }
224}
225
226/// Pre-loaded search session: all HNSW indexes loaded into memory once.
227///
228/// Useful for benchmarks and servers that issue many queries against the same
229/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
230pub struct SearchSession {
231    shards: Vec<LoadedShard>,
232    metric: VectorMetric,
233}
234
235struct LoadedShard {
236    entry: DataFileEntry,
237    /// None when the shard is still being indexed (IndexStatus::Indexing).
238    index: Option<AnyIndex>,
239    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
240    /// present for Ready shards when `load_raw = true` (reranking).
241    raw_vectors: Option<Vec<Vec<f32>>>,
242}
243
244impl SearchSession {
245    /// Load all indexes for the latest snapshot into memory.
246    ///
247    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
248    /// `Some`); it reads the full parquet columns so exact distances are
249    /// available without extra I/O during `search_query`.
250    pub async fn load(
251        table: &TableIdent,
252        vector_column: &str,
253        dim: u32,
254        catalog: Arc<dyn CatalogProvider>,
255        store: Arc<dyn Store>,
256        load_raw: bool,
257    ) -> AilakeResult<Self> {
258        let all_files = catalog.list_files(table, None).await?;
259        let table_meta = catalog.load_table(table).await?;
260        let metric = parse_metric(
261            table_meta
262                .properties
263                .get("ailake.vector-metric")
264                .map(String::as_str)
265                .unwrap_or("cosine"),
266        );
267
268        let mut shards = Vec::with_capacity(all_files.len());
269        for entry in all_files {
270            let file_bytes: Bytes = store.get(&entry.path).await?;
271            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
272
273            if entry.index_status == IndexStatus::Indexing {
274                // HNSW not yet built — load raw vectors for flat scan.
275                let (_, raw_vecs) = reader.read_parquet()?;
276                shards.push(LoadedShard {
277                    entry,
278                    index: None,
279                    raw_vectors: Some(raw_vecs),
280                });
281            } else if reader.is_ailake_file() {
282                let mut index = reader.load_any_index_for_column(vector_column)?;
283                let raw_vectors = if load_raw {
284                    index.quantize_to_f16();
285                    let (_, vecs) = reader.read_parquet()?;
286                    Some(vecs)
287                } else {
288                    None
289                };
290                shards.push(LoadedShard {
291                    entry,
292                    index: Some(index),
293                    raw_vectors,
294                });
295            }
296        }
297
298        Ok(Self { shards, metric })
299    }
300
301    /// Number of loaded shards.
302    pub fn shard_count(&self) -> usize {
303        self.shards.len()
304    }
305
306    /// Search multiple queries in one call.
307    ///
308    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
309    /// matmul when a CUDA device is available, falling back to CPU flat scan.
310    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
311    /// traversal is inherently sequential and has no GPU batch path.
312    ///
313    /// Returns one `Vec<SearchResult>` per input query, in the same order.
314    pub fn search_batch(
315        &self,
316        queries: &[Vec<f32>],
317        config: &SearchConfig,
318    ) -> Vec<Vec<SearchResult>> {
319        if queries.is_empty() {
320            return vec![];
321        }
322
323        let n_queries = queries.len();
324        let candidate_k = match config.rerank_factor {
325            Some(factor) => config.top_k * factor,
326            None => config.top_k,
327        };
328        let use_nvidia = ailake_index::hardware::detect_cuda();
329        let use_amd = ailake_index::hardware::detect_rocm();
330
331        // Accumulate per-query results across all shards.
332        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
333
334        for shard in &self.shards {
335            if let Some(raw) = &shard.raw_vectors {
336                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
337                if !raw.is_empty() {
338                    let dim = raw[0].len();
339                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
340                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
341                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
342
343                    let gpu_batch = if use_nvidia {
344                        ailake_index::gpu::try_nvidia_search_batch(
345                            &q_refs,
346                            &row_ids,
347                            &flat,
348                            dim,
349                            self.metric,
350                            candidate_k,
351                        )
352                    } else if use_amd {
353                        ailake_index::gpu::try_rocm_search_batch(
354                            &q_refs,
355                            &row_ids,
356                            &flat,
357                            dim,
358                            self.metric,
359                            candidate_k,
360                        )
361                    } else {
362                        None
363                    };
364
365                    if let Some(batch) = gpu_batch {
366                        for (qi, results) in batch.into_iter().enumerate() {
367                            for (row_id, distance) in results {
368                                all_results[qi].push(SearchResult {
369                                    row_id,
370                                    distance,
371                                    file_path: shard.entry.path.clone(),
372                                });
373                            }
374                        }
375                        continue;
376                    }
377                }
378
379                // CPU fallback for flat scan.
380                for (qi, query) in queries.iter().enumerate() {
381                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
382                        all_results[qi].push(SearchResult {
383                            row_id,
384                            distance,
385                            file_path: shard.entry.path.clone(),
386                        });
387                    }
388                }
389            } else if let Some(index) = &shard.index {
390                // Indexed shard — rayon parallel-map over queries.
391                let shard_results: Vec<Vec<SearchResult>> = queries
392                    .par_iter()
393                    .map(|query| {
394                        index
395                            .search(query, candidate_k, config.ef_search)
396                            .into_iter()
397                            .map(|(row_id, distance)| SearchResult {
398                                row_id,
399                                distance,
400                                file_path: shard.entry.path.clone(),
401                            })
402                            .collect()
403                    })
404                    .collect();
405
406                for (qi, results) in shard_results.into_iter().enumerate() {
407                    all_results[qi].extend(results);
408                }
409            }
410        }
411
412        // Sort + truncate per query.
413        for results in &mut all_results {
414            results.sort_by(|a, b| {
415                a.distance
416                    .partial_cmp(&b.distance)
417                    .unwrap_or(std::cmp::Ordering::Equal)
418            });
419            results.truncate(config.top_k);
420        }
421
422        all_results
423    }
424
425    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
426    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
427        let candidate_k = match config.rerank_factor {
428            Some(factor) => config.top_k * factor,
429            None => config.top_k,
430        };
431
432        let mut all_results: Vec<SearchResult> = self
433            .shards
434            .par_iter()
435            .flat_map(|shard| {
436                // Geometric pruning per shard.
437                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
438                    let dist = match self.metric {
439                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
440                            ailake_vec::cosine_distance(query, &centroid.values)
441                        }
442                        VectorMetric::Euclidean => {
443                            ailake_vec::euclidean_distance(query, &centroid.values)
444                        }
445                        VectorMetric::DotProduct => {
446                            -ailake_vec::dot_product(query, &centroid.values)
447                        }
448                    };
449                    if dist - centroid.radius > config.pruning_threshold {
450                        return vec![];
451                    }
452                }
453
454                if let Some(index) = &shard.index {
455                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
456                    let local_results = index.search(query, candidate_k, config.ef_search);
457                    if config.rerank_factor.is_some() {
458                        if let Some(raw) = &shard.raw_vectors {
459                            local_results
460                                .into_iter()
461                                .map(|(row_id, _approx_dist)| {
462                                    let idx = row_id.as_u64() as usize;
463                                    let exact_dist = raw
464                                        .get(idx)
465                                        .map(|v| exact_distance(self.metric, query, v))
466                                        .unwrap_or(f32::INFINITY);
467                                    SearchResult {
468                                        row_id,
469                                        distance: exact_dist,
470                                        file_path: shard.entry.path.clone(),
471                                    }
472                                })
473                                .collect()
474                        } else {
475                            local_results
476                                .into_iter()
477                                .map(|(row_id, distance)| SearchResult {
478                                    row_id,
479                                    distance,
480                                    file_path: shard.entry.path.clone(),
481                                })
482                                .collect()
483                        }
484                    } else {
485                        local_results
486                            .into_iter()
487                            .map(|(row_id, distance)| SearchResult {
488                                row_id,
489                                distance,
490                                file_path: shard.entry.path.clone(),
491                            })
492                            .collect()
493                    }
494                } else if let Some(raw) = &shard.raw_vectors {
495                    // Indexing shard: exact flat scan.
496                    flat_search(raw, query, candidate_k, self.metric)
497                        .into_iter()
498                        .map(|(row_id, distance)| SearchResult {
499                            row_id,
500                            distance,
501                            file_path: shard.entry.path.clone(),
502                        })
503                        .collect()
504                } else {
505                    vec![]
506                }
507            })
508            .collect();
509
510        all_results.sort_by(|a, b| {
511            a.distance
512                .partial_cmp(&b.distance)
513                .unwrap_or(std::cmp::Ordering::Equal)
514        });
515        all_results.truncate(config.top_k);
516        all_results
517    }
518}
519
520/// Fetch full row data for a slice of search results.
521///
522/// Groups results by Parquet file, reads each file once, extracts the matching rows
523/// via `arrow_select::take`, then concatenates everything back in original top-k order
524/// with a `_distance: Float32` column appended.
525///
526/// Use this immediately after `search()` to retrieve the actual text / metadata
527/// columns (e.g. `chunk_text`, `document_title`) alongside the distance scores.
528pub async fn fetch_rows(
529    results: &[SearchResult],
530    store: Arc<dyn Store>,
531    vector_column: &str,
532    dim: u32,
533) -> AilakeResult<RecordBatch> {
534    use std::collections::HashMap;
535
536    use arrow_array::{ArrayRef, Float32Array, UInt32Array};
537    use arrow_schema::{DataType, Field, Schema};
538    use arrow_select::{concat::concat_batches, take::take};
539
540    if results.is_empty() {
541        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
542    }
543
544    // Group by file path; preserve original position for re-sorting.
545    let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
546    for (i, r) in results.iter().enumerate() {
547        by_file
548            .entry(r.file_path.as_str())
549            .or_default()
550            .push((r.row_id.as_u64(), r.distance, i));
551    }
552
553    use arrow_array::FixedSizeListArray;
554
555    // (original_index, distance, single-row RecordBatch, decoded F32 vector)
556    let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
557
558    for (file_path, rows) in &by_file {
559        let bytes = store.get(file_path).await?;
560        let reader = AilakeFileReader::new(bytes, vector_column, dim);
561        let (batch, vectors) = reader.read_parquet()?;
562
563        for &(row_id, distance, pos) in rows {
564            let idx = row_id as usize;
565            if idx >= batch.num_rows() {
566                tracing::warn!(
567                    "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
568                    idx,
569                    batch.num_rows(),
570                    file_path
571                );
572                continue;
573            }
574
575            let indices = UInt32Array::from(vec![idx as u32]);
576            let row_cols: Vec<ArrayRef> = batch
577                .columns()
578                .iter()
579                .map(|col| {
580                    take(col.as_ref(), &indices, None)
581                        .map_err(|e| AilakeError::Arrow(e.to_string()))
582                })
583                .collect::<AilakeResult<Vec<_>>>()?;
584
585            let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
586                .map_err(|e| AilakeError::Arrow(e.to_string()))?;
587
588            // Capture decoded F32 vector for this row (empty vec if not available).
589            let vec = vectors
590                .get(idx)
591                .cloned()
592                .unwrap_or_else(|| vec![0.0f32; dim as usize]);
593
594            collected.push((pos, distance, row_batch, vec));
595        }
596    }
597
598    if collected.is_empty() {
599        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
600    }
601
602    // Restore original top-k order from the search results slice.
603    collected.sort_by_key(|(pos, _, _, _)| *pos);
604
605    let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
606    let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
607    let base_schema = collected[0].2.schema();
608
609    let combined =
610        concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
611
612    // Build FixedSizeList<Float32> column with decoded vectors (F32, not raw F16 bytes).
613    let flat_vecs: Vec<f32> = collected
614        .iter()
615        .flat_map(|(_, _, _, v)| v.iter().copied())
616        .collect();
617    let item_field = Arc::new(Field::new("item", DataType::Float32, false));
618    let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
619    let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
620    let vec_field = Arc::new(Field::new(
621        vector_column,
622        DataType::FixedSizeList(item_field, dim as i32),
623        false,
624    ));
625
626    // Schema: tabular cols, then decoded vector col, then _distance.
627    let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
628    fields.push(vec_field);
629    fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
630    let new_schema = Arc::new(Schema::new(fields));
631
632    let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
633    columns.push(Arc::new(vec_col));
634    columns.push(Arc::new(Float32Array::from(distances)));
635
636    RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy used by every test: cosine metric, F16 precision, raw
    /// vectors kept for reranking.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
        }
    }

    /// Writes and commits a single batch of `rows` rows into `default.table`
    /// under `dir`.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();

        // Each row i has embedding with 1.0 at dimension i and 0 elsewhere (unit basis vectors)
        let embeddings: Vec<Vec<f32>> = (0..rows)
            .map(|i| {
                let mut v = vec![0.0f32; dim];
                v[i % dim] = 1.0;
                v
            })
            .collect();

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    /// Query equal to row 0's embedding: 1.0 in the first dimension, 0 elsewhere.
    fn first_axis_query(dim: usize) -> Vec<f32> {
        let mut q = vec![0.0f32; dim];
        q[0] = 1.0;
        q
    }

    /// Runs `search` against the demo table in `dir` with a fresh store and
    /// catalog, pruning disabled and `ef_search = 50`.
    async fn run_search(
        dir: &TempDir,
        dim: usize,
        top_k: usize,
        rerank_factor: Option<usize>,
    ) -> Vec<SearchResult> {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        let table = TableIdent::new("default", "table");
        let config = SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        };

        search(
            &table,
            &first_axis_query(dim),
            config,
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let results = run_search(&dir, dim, 3, Some(2)).await;

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        // Row 0 has [1,0,0,...] — cosine distance to same query is 0
        let results = run_search(&dir, dim, 1, Some(4)).await;

        assert_eq!(results.len(), 1);
        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed)
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let plain = run_search(&dir, dim, 2, None).await;
        let reranked = run_search(&dir, dim, 2, Some(2)).await;

        // Fail with a clear message instead of an index panic if either is empty.
        assert!(!plain.is_empty(), "plain search returned no results");
        assert!(!reranked.is_empty(), "reranked search returned no results");

        // Both should return same top-1 result (row 0, distance ~0)
        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }
}