Skip to main content

ailake_query/
scanner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::RecordBatch;
14use bytes::Bytes;
15
16use crate::pruner::VectorPruner;
17
/// Tuning knobs for a vector search.
///
/// Start from [`SearchConfig::default`] and adjust it with the `with_*` builders, or set
/// the public fields directly.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results kept after the global merge.
    pub top_k: usize,
    /// Passed as the `ef_search` argument of every index `search` call.
    pub ef_search: usize,
    /// Maximum distance from query to file centroid edge for a file to be searched:
    /// files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
    /// `f32::INFINITY` disables pruning (every file is scanned).
    pub pruning_threshold: f32,
    /// When `Some(factor)`, fetch `top_k * factor` candidates from the index and rerank
    /// them using exact F32 distances before truncating to `top_k`, correcting the
    /// approximation error introduced by PQ-compressed index distances.
    /// `None` (the default) disables reranking.
    pub rerank_factor: Option<usize>,
}

impl SearchConfig {
    const DEFAULT_TOP_K: usize = 10;
    const DEFAULT_EF_SEARCH: usize = 50;

    /// Returns this config with `pruning_threshold` replaced by `threshold`.
    pub fn with_pruning(self, threshold: f32) -> Self {
        Self {
            pruning_threshold: threshold,
            ..self
        }
    }

    /// Returns this config with exact-distance reranking enabled using `factor`.
    pub fn with_reranking(self, factor: usize) -> Self {
        Self {
            rerank_factor: Some(factor),
            ..self
        }
    }
}

impl Default for SearchConfig {
    /// `top_k = 10`, `ef_search = 50`, pruning disabled, reranking disabled.
    fn default() -> Self {
        Self {
            top_k: Self::DEFAULT_TOP_K,
            ef_search: Self::DEFAULT_EF_SEARCH,
            pruning_threshold: f32::INFINITY,
            rerank_factor: None,
        }
    }
}
55
56#[derive(Debug)]
57pub struct SearchResult {
58    pub row_id: RowId,
59    pub distance: f32,
60    pub file_path: String,
61}
62
63/// Search across all files in the latest snapshot, with geometric pruning.
64///
65/// Flow:
66/// 1. Load file list from catalog (includes centroid metadata)
67/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
68/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
69/// 4. Global merge of all per-file top-k lists, return global top-k
70pub async fn search(
71    table: &TableIdent,
72    query: &[f32],
73    config: SearchConfig,
74    vector_column: &str,
75    dim: u32,
76    catalog: Arc<dyn CatalogProvider>,
77    store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79    // Get file metadata (includes centroid info) without reading any data files
80    let all_files = catalog.list_files(table, None).await?;
81
82    // Determine vector metric from table metadata for correct distance computation
83    let table_meta = catalog.load_table(table).await?;
84
85    // Validate query dim against the column's stored dim.
86    // Primary column: use `ailake.vector-dim`. Secondary columns: use `ailake.dim-<col>`.
87    // Skip validation when the column has no stored dim (e.g. old tables written before
88    // multi-column support).
89    let primary_col = table_meta
90        .properties
91        .get("ailake.vector-column")
92        .map(String::as_str)
93        .unwrap_or("");
94    let stored_dim_key = if vector_column == primary_col {
95        "ailake.vector-dim".to_string()
96    } else {
97        format!("ailake.dim-{vector_column}")
98    };
99    if let Some(table_dim_str) = table_meta.properties.get(&stored_dim_key) {
100        if let Ok(table_dim) = table_dim_str.parse::<u32>() {
101            let query_dim = query.len() as u32;
102            if query_dim != table_dim {
103                let table_model = table_meta
104                    .properties
105                    .get(EmbeddingModelInfo::property_key())
106                    .cloned()
107                    .unwrap_or_else(|| format!("dim={}", table_dim));
108                return Err(AilakeError::ModelMismatch {
109                    table_model,
110                    table_dim,
111                    batch_model: format!("query dim={}", query_dim),
112                    batch_dim: query_dim,
113                });
114            }
115        }
116    }
117
118    // Metric: prefer per-column `ailake.metric-<col>`, fall back to primary metric.
119    let metric_key = if vector_column == primary_col {
120        "ailake.vector-metric".to_string()
121    } else {
122        format!("ailake.metric-{vector_column}")
123    };
124    let metric = parse_metric(
125        table_meta
126            .properties
127            .get(&metric_key)
128            .or_else(|| table_meta.properties.get("ailake.vector-metric"))
129            .map(String::as_str)
130            .unwrap_or("cosine"),
131    );
132
133    // Geometric pruning: skip files whose centroid is too far from the query
134    let total_files = all_files.len();
135    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
136    debug!(
137        "ailake: geometric pruning — {}/{} files survive (threshold={})",
138        surviving_files.len(),
139        total_files,
140        config.pruning_threshold
141    );
142
143    let candidate_k = match config.rerank_factor {
144        Some(factor) => config.top_k * factor,
145        None => config.top_k,
146    };
147
148    let mut all_results: Vec<SearchResult> = Vec::new();
149
150    for file_entry in &surviving_files {
151        let file_bytes: Bytes = store.get(&file_entry.path).await?;
152        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
153
154        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
155            // HNSW not yet built — flat scan over raw vectors.
156            debug!(
157                "ailake: flat scan fallback for {} (index_status={:?})",
158                file_entry.path, file_entry.index_status
159            );
160            let (_, raw_vectors) = reader.read_parquet()?;
161            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
162                all_results.push(SearchResult {
163                    row_id,
164                    distance,
165                    file_path: file_entry.path.clone(),
166                });
167            }
168            continue;
169        }
170
171        let index = reader.load_any_index_for_column(vector_column)?;
172        let local_results = index.search(query, candidate_k, config.ef_search);
173
174        if config.rerank_factor.is_some() {
175            // Read raw F32 vectors for exact distance reranking; file bytes already loaded.
176            let (_, raw_vectors) = reader.read_parquet()?;
177            for (row_id, _approx_dist) in local_results {
178                let idx = row_id.as_u64() as usize;
179                let exact_dist = match raw_vectors.get(idx) {
180                    Some(v) => exact_distance(metric, query, v),
181                    None => {
182                        error!(
183                            "ailake: invariant violated — row_id {} out of bounds \
184                             (raw_vectors.len={}, file={}); \
185                             Parquet row count and HNSW node count are out of sync; \
186                             file may be corrupt — run compaction to rebuild",
187                            idx,
188                            raw_vectors.len(),
189                            file_entry.path
190                        );
191                        f32::INFINITY
192                    }
193                };
194                all_results.push(SearchResult {
195                    row_id,
196                    distance: exact_dist,
197                    file_path: file_entry.path.clone(),
198                });
199            }
200        } else {
201            for (row_id, distance) in local_results {
202                all_results.push(SearchResult {
203                    row_id,
204                    distance,
205                    file_path: file_entry.path.clone(),
206                });
207            }
208        }
209    }
210
211    // Global merge: sort all candidates by distance, keep top-k
212    all_results.sort_by(|a, b| {
213        a.distance
214            .partial_cmp(&b.distance)
215            .unwrap_or(std::cmp::Ordering::Equal)
216    });
217    all_results.truncate(config.top_k);
218    Ok(all_results)
219}
220
/// A single modality's query within a cross-modal search.
#[derive(Debug, Clone)]
pub struct ModalQuery<'a> {
    /// Name of the vector column to search; the column must exist in the table.
    pub column: &'a str,
    /// Query vector for this column.
    pub query: &'a [f32],
    /// RRF weight for this modality, applied as `weight / (k + rank)`.
    /// Give every arm `1.0` to weight all modalities equally.
    pub weight: f32,
    /// Vector dimensionality of `column`. `0` means "look it up in table metadata"
    /// (`ailake.vector-dim` for the primary column, `ailake.dim-<column>` otherwise).
    pub dim: u32,
}
235
/// Strategy used to merge the ranked result lists of several vector columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion with the standard `k = 60`:
    /// `score(d) = Σ weight_i / (k + rank_i(d))`. The reported `SearchResult.distance`
    /// is `-rrf_score`, so sorting by ascending distance yields the RRF ranking.
    Rrf,
}
244
245/// Cross-modal search: run independent HNSW searches across N vector columns,
246/// then fuse per-column ranked lists using Reciprocal Rank Fusion.
247///
248/// Each `ModalQuery` specifies a column name, its query vector, RRF weight, and dim.
249/// When `ModalQuery.dim == 0`, the dim is auto-detected from `ailake.dim-<col>` /
250/// `ailake.vector-dim` in table metadata.
251/// Results are de-duplicated by `(file_path, row_id)` and ranked by aggregate
252/// RRF score. `SearchResult.distance` stores `-rrf_score` (lower = better) so
253/// existing sort-ascending callers get the correct ordering.
254pub async fn search_multimodal(
255    table: &TableIdent,
256    queries: &[ModalQuery<'_>],
257    config: SearchConfig,
258    catalog: Arc<dyn CatalogProvider>,
259    store: Arc<dyn Store>,
260    fusion: FusionMethod,
261) -> AilakeResult<Vec<SearchResult>> {
262    use std::collections::HashMap;
263
264    if queries.is_empty() {
265        return Err(AilakeError::InvalidArgument(
266            "search_multimodal requires at least one ModalQuery".into(),
267        ));
268    }
269
270    // Load table metadata once for dim auto-detection and metric resolution.
271    let table_meta = catalog.load_table(table).await?;
272    let primary_col = table_meta
273        .properties
274        .get("ailake.vector-column")
275        .cloned()
276        .unwrap_or_default();
277    let primary_dim: u32 = table_meta
278        .properties
279        .get("ailake.vector-dim")
280        .and_then(|s| s.parse().ok())
281        .unwrap_or(0);
282
283    // Fetch more candidates per column so RRF has enough to fuse.
284    let per_col_k = (config.top_k * queries.len().max(2)).min(1000);
285
286    let mut per_col_results: Vec<(f32, Vec<SearchResult>)> = Vec::with_capacity(queries.len());
287    for mq in queries {
288        // Resolve dim: caller-supplied > per-column property > primary column dim.
289        let resolved_dim = if mq.dim > 0 {
290            mq.dim
291        } else if mq.column == primary_col {
292            primary_dim
293        } else {
294            table_meta
295                .properties
296                .get(&format!("ailake.dim-{}", mq.column))
297                .and_then(|s| s.parse().ok())
298                .unwrap_or(mq.query.len() as u32)
299        };
300
301        let col_config = SearchConfig {
302            top_k: per_col_k,
303            ef_search: config.ef_search,
304            pruning_threshold: config.pruning_threshold,
305            rerank_factor: config.rerank_factor,
306        };
307        let results = search(
308            table,
309            mq.query,
310            col_config,
311            mq.column,
312            resolved_dim,
313            catalog.clone(),
314            store.clone(),
315        )
316        .await?;
317        per_col_results.push((mq.weight, results));
318    }
319
320    // RRF fusion: accumulate score per (file_path, row_id).
321    const K: f32 = 60.0;
322    let mut scores: HashMap<(String, u64), f32> = HashMap::new();
323
324    for (weight, results) in &per_col_results {
325        for (rank, r) in results.iter().enumerate() {
326            let key = (r.file_path.clone(), r.row_id.as_u64());
327            let rrf = weight / (K + rank as f32 + 1.0);
328            *scores.entry(key).or_insert(0.0) += rrf;
329        }
330    }
331
332    // Build SearchResult list sorted by descending RRF score.
333    // Store `-rrf_score` as `.distance` so callers sorting ascending get correct order.
334    let all_files = catalog.list_files(table, None).await?;
335    let _ = all_files; // centroid not needed for fusion — just need file_path+row_id
336
337    // Collect unique candidates: prefer the row's appearance in the first column's results.
338    let mut seen: HashMap<(String, u64), f32> = HashMap::new();
339    for (_, results) in &per_col_results {
340        for r in results {
341            let key = (r.file_path.clone(), r.row_id.as_u64());
342            let rrf_score = *scores.get(&key).unwrap_or(&0.0);
343            seen.entry(key).or_insert(rrf_score);
344        }
345    }
346
347    let mut fused: Vec<SearchResult> = seen
348        .into_iter()
349        .map(|((file_path, row_id_u64), rrf_score)| SearchResult {
350            row_id: RowId::new(row_id_u64),
351            distance: -rrf_score,
352            file_path,
353        })
354        .collect();
355
356    fused.sort_by(|a, b| {
357        a.distance
358            .partial_cmp(&b.distance)
359            .unwrap_or(std::cmp::Ordering::Equal)
360    });
361    fused.truncate(config.top_k);
362
363    let _ = fusion; // only RRF implemented; enum is extensible
364
365    Ok(fused)
366}
367
368/// Brute-force top-k search over raw vectors. Used for Indexing shards.
369fn flat_search(
370    raw: &[Vec<f32>],
371    query: &[f32],
372    top_k: usize,
373    metric: VectorMetric,
374) -> Vec<(RowId, f32)> {
375    let mut results: Vec<(RowId, f32)> = raw
376        .iter()
377        .enumerate()
378        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
379        .collect();
380    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
381    results.truncate(top_k);
382    results
383}
384
385fn parse_metric(s: &str) -> VectorMetric {
386    match s {
387        "euclidean" => VectorMetric::Euclidean,
388        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
389        _ => VectorMetric::Cosine,
390    }
391}
392
393/// Pre-loaded search session: all HNSW indexes loaded into memory once.
394///
395/// Useful for benchmarks and servers that issue many queries against the same
396/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
397pub struct SearchSession {
398    shards: Vec<LoadedShard>,
399    metric: VectorMetric,
400}
401
402struct LoadedShard {
403    entry: DataFileEntry,
404    /// None when the shard is still being indexed (IndexStatus::Indexing).
405    index: Option<AnyIndex>,
406    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
407    /// present for Ready shards when `load_raw = true` (reranking).
408    raw_vectors: Option<Vec<Vec<f32>>>,
409}
410
411impl SearchSession {
412    /// Load all indexes for the latest snapshot into memory.
413    ///
414    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
415    /// `Some`); it reads the full parquet columns so exact distances are
416    /// available without extra I/O during `search_query`.
417    pub async fn load(
418        table: &TableIdent,
419        vector_column: &str,
420        dim: u32,
421        catalog: Arc<dyn CatalogProvider>,
422        store: Arc<dyn Store>,
423        load_raw: bool,
424    ) -> AilakeResult<Self> {
425        let all_files = catalog.list_files(table, None).await?;
426        let table_meta = catalog.load_table(table).await?;
427        let metric = parse_metric(
428            table_meta
429                .properties
430                .get("ailake.vector-metric")
431                .map(String::as_str)
432                .unwrap_or("cosine"),
433        );
434
435        let mut shards = Vec::with_capacity(all_files.len());
436        for entry in all_files {
437            let file_bytes: Bytes = store.get(&entry.path).await?;
438            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
439
440            if entry.index_status == IndexStatus::Indexing {
441                // HNSW not yet built — load raw vectors for flat scan.
442                let (_, raw_vecs) = reader.read_parquet()?;
443                shards.push(LoadedShard {
444                    entry,
445                    index: None,
446                    raw_vectors: Some(raw_vecs),
447                });
448            } else if reader.is_ailake_file() {
449                let mut index = reader.load_any_index_for_column(vector_column)?;
450                let raw_vectors = if load_raw {
451                    index.quantize_to_f16();
452                    let (_, vecs) = reader.read_parquet()?;
453                    Some(vecs)
454                } else {
455                    None
456                };
457                shards.push(LoadedShard {
458                    entry,
459                    index: Some(index),
460                    raw_vectors,
461                });
462            }
463        }
464
465        Ok(Self { shards, metric })
466    }
467
468    /// Number of loaded shards.
469    pub fn shard_count(&self) -> usize {
470        self.shards.len()
471    }
472
473    /// Search multiple queries in one call.
474    ///
475    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
476    /// matmul when a CUDA device is available, falling back to CPU flat scan.
477    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
478    /// traversal is inherently sequential and has no GPU batch path.
479    ///
480    /// Returns one `Vec<SearchResult>` per input query, in the same order.
481    pub fn search_batch(
482        &self,
483        queries: &[Vec<f32>],
484        config: &SearchConfig,
485    ) -> Vec<Vec<SearchResult>> {
486        if queries.is_empty() {
487            return vec![];
488        }
489
490        let n_queries = queries.len();
491        let candidate_k = match config.rerank_factor {
492            Some(factor) => config.top_k * factor,
493            None => config.top_k,
494        };
495        let use_nvidia = ailake_index::hardware::detect_cuda();
496        let use_amd = ailake_index::hardware::detect_rocm();
497
498        // Accumulate per-query results across all shards.
499        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
500
501        for shard in &self.shards {
502            if let Some(raw) = &shard.raw_vectors {
503                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
504                if !raw.is_empty() {
505                    let dim = raw[0].len();
506                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
507                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
508                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
509
510                    let gpu_batch = if use_nvidia {
511                        ailake_index::gpu::try_nvidia_search_batch(
512                            &q_refs,
513                            &row_ids,
514                            &flat,
515                            dim,
516                            self.metric,
517                            candidate_k,
518                        )
519                    } else if use_amd {
520                        ailake_index::gpu::try_rocm_search_batch(
521                            &q_refs,
522                            &row_ids,
523                            &flat,
524                            dim,
525                            self.metric,
526                            candidate_k,
527                        )
528                    } else {
529                        None
530                    };
531
532                    if let Some(batch) = gpu_batch {
533                        for (qi, results) in batch.into_iter().enumerate() {
534                            for (row_id, distance) in results {
535                                all_results[qi].push(SearchResult {
536                                    row_id,
537                                    distance,
538                                    file_path: shard.entry.path.clone(),
539                                });
540                            }
541                        }
542                        continue;
543                    }
544                }
545
546                // CPU fallback for flat scan.
547                for (qi, query) in queries.iter().enumerate() {
548                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
549                        all_results[qi].push(SearchResult {
550                            row_id,
551                            distance,
552                            file_path: shard.entry.path.clone(),
553                        });
554                    }
555                }
556            } else if let Some(index) = &shard.index {
557                // Indexed shard — rayon parallel-map over queries.
558                let shard_results: Vec<Vec<SearchResult>> = queries
559                    .par_iter()
560                    .map(|query| {
561                        index
562                            .search(query, candidate_k, config.ef_search)
563                            .into_iter()
564                            .map(|(row_id, distance)| SearchResult {
565                                row_id,
566                                distance,
567                                file_path: shard.entry.path.clone(),
568                            })
569                            .collect()
570                    })
571                    .collect();
572
573                for (qi, results) in shard_results.into_iter().enumerate() {
574                    all_results[qi].extend(results);
575                }
576            }
577        }
578
579        // Sort + truncate per query.
580        for results in &mut all_results {
581            results.sort_by(|a, b| {
582                a.distance
583                    .partial_cmp(&b.distance)
584                    .unwrap_or(std::cmp::Ordering::Equal)
585            });
586            results.truncate(config.top_k);
587        }
588
589        all_results
590    }
591
592    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
593    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
594        let candidate_k = match config.rerank_factor {
595            Some(factor) => config.top_k * factor,
596            None => config.top_k,
597        };
598
599        let mut all_results: Vec<SearchResult> = self
600            .shards
601            .par_iter()
602            .flat_map(|shard| {
603                // Geometric pruning per shard.
604                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
605                    let dist = match self.metric {
606                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
607                            ailake_vec::cosine_distance(query, &centroid.values)
608                        }
609                        VectorMetric::Euclidean => {
610                            ailake_vec::euclidean_distance(query, &centroid.values)
611                        }
612                        VectorMetric::DotProduct => {
613                            -ailake_vec::dot_product(query, &centroid.values)
614                        }
615                    };
616                    if dist - centroid.radius > config.pruning_threshold {
617                        return vec![];
618                    }
619                }
620
621                if let Some(index) = &shard.index {
622                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
623                    let local_results = index.search(query, candidate_k, config.ef_search);
624                    if config.rerank_factor.is_some() {
625                        if let Some(raw) = &shard.raw_vectors {
626                            local_results
627                                .into_iter()
628                                .map(|(row_id, _approx_dist)| {
629                                    let idx = row_id.as_u64() as usize;
630                                    let exact_dist = raw
631                                        .get(idx)
632                                        .map(|v| exact_distance(self.metric, query, v))
633                                        .unwrap_or(f32::INFINITY);
634                                    SearchResult {
635                                        row_id,
636                                        distance: exact_dist,
637                                        file_path: shard.entry.path.clone(),
638                                    }
639                                })
640                                .collect()
641                        } else {
642                            local_results
643                                .into_iter()
644                                .map(|(row_id, distance)| SearchResult {
645                                    row_id,
646                                    distance,
647                                    file_path: shard.entry.path.clone(),
648                                })
649                                .collect()
650                        }
651                    } else {
652                        local_results
653                            .into_iter()
654                            .map(|(row_id, distance)| SearchResult {
655                                row_id,
656                                distance,
657                                file_path: shard.entry.path.clone(),
658                            })
659                            .collect()
660                    }
661                } else if let Some(raw) = &shard.raw_vectors {
662                    // Indexing shard: exact flat scan.
663                    flat_search(raw, query, candidate_k, self.metric)
664                        .into_iter()
665                        .map(|(row_id, distance)| SearchResult {
666                            row_id,
667                            distance,
668                            file_path: shard.entry.path.clone(),
669                        })
670                        .collect()
671                } else {
672                    vec![]
673                }
674            })
675            .collect();
676
677        all_results.sort_by(|a, b| {
678            a.distance
679                .partial_cmp(&b.distance)
680                .unwrap_or(std::cmp::Ordering::Equal)
681        });
682        all_results.truncate(config.top_k);
683        all_results
684    }
685}
686
687/// Fetch full row data for a slice of search results.
688///
689/// Groups results by Parquet file, reads each file once, extracts the matching rows
690/// via `arrow_select::take`, then concatenates everything back in original top-k order
691/// with a `_distance: Float32` column appended.
692///
693/// Use this immediately after `search()` to retrieve the actual text / metadata
694/// columns (e.g. `chunk_text`, `document_title`) alongside the distance scores.
695pub async fn fetch_rows(
696    results: &[SearchResult],
697    store: Arc<dyn Store>,
698    vector_column: &str,
699    dim: u32,
700) -> AilakeResult<RecordBatch> {
701    use std::collections::HashMap;
702
703    use arrow_array::{ArrayRef, Float32Array, UInt32Array};
704    use arrow_schema::{DataType, Field, Schema};
705    use arrow_select::{concat::concat_batches, take::take};
706
707    if results.is_empty() {
708        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
709    }
710
711    // Group by file path; preserve original position for re-sorting.
712    let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
713    for (i, r) in results.iter().enumerate() {
714        by_file
715            .entry(r.file_path.as_str())
716            .or_default()
717            .push((r.row_id.as_u64(), r.distance, i));
718    }
719
720    use arrow_array::FixedSizeListArray;
721
722    // (original_index, distance, single-row RecordBatch, decoded F32 vector)
723    let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
724
725    for (file_path, rows) in &by_file {
726        let bytes = store.get(file_path).await?;
727        let reader = AilakeFileReader::new(bytes, vector_column, dim);
728        let (batch, vectors) = reader.read_parquet()?;
729
730        for &(row_id, distance, pos) in rows {
731            let idx = row_id as usize;
732            if idx >= batch.num_rows() {
733                tracing::warn!(
734                    "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
735                    idx,
736                    batch.num_rows(),
737                    file_path
738                );
739                continue;
740            }
741
742            let indices = UInt32Array::from(vec![idx as u32]);
743            let row_cols: Vec<ArrayRef> = batch
744                .columns()
745                .iter()
746                .map(|col| {
747                    take(col.as_ref(), &indices, None)
748                        .map_err(|e| AilakeError::Arrow(e.to_string()))
749                })
750                .collect::<AilakeResult<Vec<_>>>()?;
751
752            let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
753                .map_err(|e| AilakeError::Arrow(e.to_string()))?;
754
755            // Capture decoded F32 vector for this row (empty vec if not available).
756            let vec = vectors
757                .get(idx)
758                .cloned()
759                .unwrap_or_else(|| vec![0.0f32; dim as usize]);
760
761            collected.push((pos, distance, row_batch, vec));
762        }
763    }
764
765    if collected.is_empty() {
766        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
767    }
768
769    // Restore original top-k order from the search results slice.
770    collected.sort_by_key(|(pos, _, _, _)| *pos);
771
772    let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
773    let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
774    let base_schema = collected[0].2.schema();
775
776    let combined =
777        concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
778
779    // Build FixedSizeList<Float32> column with decoded vectors (F32, not raw F16 bytes).
780    let flat_vecs: Vec<f32> = collected
781        .iter()
782        .flat_map(|(_, _, _, v)| v.iter().copied())
783        .collect();
784    let item_field = Arc::new(Field::new("item", DataType::Float32, false));
785    let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
786    let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
787    let vec_field = Arc::new(Field::new(
788        vector_column,
789        DataType::FixedSizeList(item_field, dim as i32),
790        false,
791    ));
792
793    // Schema: tabular cols, then decoded vector col, then _distance.
794    let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
795    fields.push(vec_field);
796    fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
797    let new_schema = Arc::new(Schema::new(fields));
798
799    let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
800    columns.push(Arc::new(vec_col));
801    columns.push(Arc::new(Float32Array::from(distances)));
802
803    RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
804}
805
#[cfg(test)]
mod tests {
    //! End-to-end tests for `search` / `search_multimodal` against small tables
    //! written through `TableWriter` into a temporary directory.

    use super::*;
    use crate::writer::MultiVectorBatch;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Warehouse prefix passed to every `HadoopCatalog` in these tests.
    const WAREHOUSE: &str = "warehouse";

    /// Builds an F16 / cosine storage policy for `column` that keeps raw vectors
    /// for reranking and leaves every optional index knob unset.
    fn make_policy_for(column: &str, dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: column.to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
        }
    }

    /// Policy for the primary `"embedding"` column used by most tests.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        make_policy_for("embedding", dim)
    }

    /// Identifier of the single table every test writes and searches.
    fn demo_table() -> TableIdent {
        TableIdent::new("default", "table")
    }

    /// Opens a local store rooted at `dir` and a Hadoop catalog on top of it.
    fn open_catalog(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), WAREHOUSE));
        (store, catalog)
    }

    /// Single-column (`id: Int32`, non-nullable) batch holding `0..rows`.
    fn id_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap()
    }

    /// `dim`-dimensional vector with 1.0 at `axis` and 0.0 everywhere else.
    fn basis_vector(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis] = 1.0;
        v
    }

    /// One embedding per row: row `i` is the unit basis vector `e_{i % dim}`.
    fn basis_vectors(rows: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..rows).map(|i| basis_vector(dim, i % dim)).collect()
    }

    /// Writes and commits `rows` rows to the demo table: an `id` column plus an
    /// `embedding` column of unit basis vectors (see `basis_vectors`).
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let (store, catalog) = open_catalog(dir);
        let batch = id_batch(rows);
        let embeddings = basis_vectors(rows, dim);

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), demo_table())
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let (store, catalog) = open_catalog(&dir);
        let table = demo_table();
        let query = basis_vector(dim, 0);
        let config = SearchConfig {
            top_k: 3,
            ..SearchConfig::default()
        }
        .with_reranking(2);

        let results = search(&table, &query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;

        let (store, catalog) = open_catalog(&dir);
        let table = demo_table();
        // Row 0 stores e_0, identical to the query, so its cosine distance is 0.
        let query = basis_vector(dim, 0);
        let config = SearchConfig {
            top_k: 1,
            ..SearchConfig::default()
        }
        .with_reranking(4);

        let results = search(&table, &query, config, "embedding", dim as u32, catalog, store)
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed).
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        // Each search gets its own store/catalog instances.
        let (store_a, cat_a) = open_catalog(&dir);
        let (store_b, cat_b) = open_catalog(&dir);
        let table = demo_table();
        let query = basis_vector(dim, 0);

        let cfg_plain = SearchConfig {
            top_k: 2,
            ..SearchConfig::default()
        };
        let cfg_rerank = cfg_plain.clone().with_reranking(2);

        let plain = search(
            &table,
            &query,
            cfg_plain,
            "embedding",
            dim as u32,
            cat_a,
            store_a,
        )
        .await
        .unwrap();
        let reranked = search(
            &table,
            &query,
            cfg_rerank,
            "embedding",
            dim as u32,
            cat_b,
            store_b,
        )
        .await
        .unwrap();

        // Guard the indexing below so an empty result fails with a clear message.
        assert!(!plain.is_empty(), "plain search returned no results");
        assert!(!reranked.is_empty(), "reranked search returned no results");
        // Both paths must agree on the top-1 hit: row 0, which equals the query exactly.
        assert_eq!(plain[0].row_id, reranked[0].row_id);
        assert_eq!(reranked[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn multimodal_rrf_returns_top_k() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        let (store, catalog) = open_catalog(&dir);
        let table = demo_table();

        // Two modal queries on the same column (single-column table), with
        // different query vectors to exercise RRF merging.
        let q1 = basis_vector(dim, 0);
        let q2 = basis_vector(dim, 1);

        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q1,
                weight: 0.7,
                dim: dim as u32,
            },
            ModalQuery {
                column: "embedding",
                query: &q2,
                weight: 0.3,
                dim: dim as u32,
            },
        ];

        let config = SearchConfig {
            top_k: 2,
            ..SearchConfig::default()
        };

        let results =
            search_multimodal(&table, &queries, config, catalog, store, FusionMethod::Rrf)
                .await
                .unwrap();

        assert_eq!(results.len(), 2);
        // The fused RRF score is reported as `-distance`, so it is never positive.
        assert!(results[0].distance <= 0.0);
        // Rows 1..=3 are all orthogonal to q1 (and rows 0, 2, 3 to q2), so the ranks
        // of those equidistant rows depend on index tie-breaking. Only check that the
        // top hit is a real row of this 4-row table.
        assert!(results[0].row_id.as_u64() < 4);
    }

    /// True cross-modal test: two columns with DIFFERENT dims (4 + 2).
    /// Verifies that search_multimodal correctly routes to each column's HNSW
    /// and that the dim validation in search() handles secondary columns.
    #[tokio::test]
    async fn multimodal_rrf_cross_modal_different_dims() {
        let dir = TempDir::new().unwrap();
        let (store, catalog) = open_catalog(&dir);
        let table = demo_table();

        // Write a 2-column table: "embedding" dim=4, "img_embedding" dim=2.
        let rows = 4usize;
        let batch = id_batch(rows);
        let text_embs = basis_vectors(rows, 4);
        let img_embs = basis_vectors(rows, 2);

        let mut writer = crate::TableWriter::create_or_open(
            catalog.clone(),
            store.clone(),
            make_policy(4),
            table.clone(),
        )
        .await
        .unwrap();

        let batches = [
            MultiVectorBatch {
                policy: make_policy(4),
                embeddings: &text_embs,
            },
            MultiVectorBatch {
                policy: make_policy_for("img_embedding", 2),
                embeddings: &img_embs,
            },
        ];
        writer.write_batch_multi(&batch, &batches).await.unwrap();
        writer.commit().await.unwrap();

        // Cross-modal search: text query (dim=4) + image query (dim=2).
        let q_text = basis_vector(4, 0);
        let q_img = basis_vector(2, 0);

        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q_text,
                weight: 0.6,
                dim: 4,
            },
            ModalQuery {
                column: "img_embedding",
                query: &q_img,
                weight: 0.4,
                dim: 2,
            },
        ];
        let config = SearchConfig {
            top_k: 2,
            ..SearchConfig::default()
        };

        let results =
            search_multimodal(&table, &queries, config, catalog, store, FusionMethod::Rrf)
                .await
                .unwrap();

        assert!(!results.is_empty(), "should return results");
        assert!(results[0].distance <= 0.0, "distance is -rrf_score");
        // Row 0 is nearest to both q_text=[1,0,0,0] and q_img=[1,0].
        assert_eq!(results[0].row_id.as_u64(), 0, "row 0 should rank first");
    }
}