Skip to main content

ailake_query/
scanner.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error, warn};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::{Array, RecordBatch};
14use bytes::Bytes;
15
16use crate::equality_delete::EqualityDeleteFilter;
17use crate::pruner::{BloomPruner, VectorPruner};
18use crate::schema_filler::SchemaFiller;
19
20/// Injectable per-result scoring function for hybrid ranking.
21///
22/// Called after HNSW retrieval with the HNSW distance and a single-row
23/// `RecordBatch` containing all Parquet columns for that result. Returns a
24/// replacement score (lower = better rank, same convention as distance).
25///
26/// Typical use: combine HNSW distance with recency and importance signals
27/// from the `episodic_columns` for agent memory tables:
28///
29/// ```rust,no_run
30/// use ailake_core::{hybrid_score, episodic_columns};
31/// use ailake_query::scanner::ScoreFn;
32/// use arrow_array::{RecordBatch, cast::AsArray};
33/// use arrow_array::types::Float32Type;
34///
35/// let score_fn = ScoreFn::new(|distance, row| {
36///     let recency = row
37///         .column_by_name(episodic_columns::RECENCY_WEIGHT)
38///         .and_then(|c| c.as_primitive_opt::<Float32Type>())
39///         .and_then(|a| a.iter().next().flatten())
40///         .unwrap_or(1.0);
41///     let importance = row
42///         .column_by_name(episodic_columns::IMPORTANCE_SCORE)
43///         .and_then(|c| c.as_primitive_opt::<Float32Type>())
44///         .and_then(|a| a.iter().next().flatten())
45///         .unwrap_or(1.0);
46///     hybrid_score(distance, recency, importance)
47/// });
48/// ```
49#[allow(clippy::type_complexity)]
50pub struct ScoreFn(pub std::sync::Arc<dyn Fn(f32, &RecordBatch) -> f32 + Send + Sync>);
51
52impl ScoreFn {
53    pub fn new(f: impl Fn(f32, &RecordBatch) -> f32 + Send + Sync + 'static) -> Self {
54        Self(std::sync::Arc::new(f))
55    }
56
57    #[inline]
58    pub fn call(&self, distance: f32, row: &RecordBatch) -> f32 {
59        (self.0)(distance, row)
60    }
61}
62
63impl Clone for ScoreFn {
64    fn clone(&self) -> Self {
65        Self(std::sync::Arc::clone(&self.0))
66    }
67}
68
69impl std::fmt::Debug for ScoreFn {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.write_str("ScoreFn(<fn>)")
72    }
73}
74
75#[derive(Debug, Clone)]
76pub struct SearchConfig {
77    pub top_k: usize,
78    pub ef_search: usize,
79    /// Maximum distance from query to file centroid edge for a file to be searched.
80    /// Files where `distance(query, centroid) - radius > pruning_threshold` are skipped.
81    /// Set to `f32::INFINITY` to disable pruning (scan all files).
82    pub pruning_threshold: f32,
83    /// When `Some(factor)`, fetch `top_k * factor` candidates from the HNSW index and
84    /// rerank them using exact F32 distances before truncating to `top_k`.
85    /// Corrects the approximation error introduced by PQ-compressed HNSW distances.
86    /// `None` (default) disables reranking.
87    pub rerank_factor: Option<usize>,
88    /// Hybrid BM25+vector search configuration.
89    ///
90    /// When set, the pipeline loads global IDF stats from the table's BM25 stats file,
91    /// fetches a larger candidate pool from HNSW (`candidate_pool` or `10 * top_k`),
92    /// scores each candidate with BM25 against `query_text`, then fuses vector distance
93    /// and BM25 score via RRF (default) or linear combination.
94    ///
95    /// The BM25 stats file (`metadata/ailake_bm25_stats.bin`) is populated automatically
96    /// by `TableWriter` when `bm25_text_column` is configured. If absent, pure vector
97    /// distances are used (BM25 scores default to 0).
98    pub hybrid: Option<crate::bm25::HybridConfig>,
99    /// Optional scoring function for hybrid ranking.
100    ///
101    /// When set, the search pipeline reads the Parquet row for each HNSW
102    /// candidate and calls `score_fn(distance, &single_row_batch)`. The
103    /// returned value replaces `distance` in `SearchResult` and determines
104    /// final ranking (lower = better).
105    ///
106    /// If `rerank_factor` is also set, `score_fn` receives the exact
107    /// (non-approximated) distance from the reranking step.
108    ///
109    /// Use `ScoreFn::new(|d, row| ...)` to construct. See `ScoreFn` docs
110    /// for an example using `hybrid_score` with episodic memory columns.
111    pub score_fn: Option<ScoreFn>,
112    /// Partition filter: only search files whose `DataFileEntry::partition_value`
113    /// matches this string. `None` searches all files (no partition pruning).
114    /// Set to `agent_id` in Agent.recall() for per-agent isolated search.
115    pub partition_filter: Option<String>,
116}
117
118impl Default for SearchConfig {
119    fn default() -> Self {
120        Self {
121            top_k: 10,
122            ef_search: 50,
123            pruning_threshold: f32::INFINITY,
124            rerank_factor: None,
125            score_fn: None,
126            partition_filter: None,
127            hybrid: None,
128        }
129    }
130}
131
132impl SearchConfig {
133    pub fn with_pruning(mut self, threshold: f32) -> Self {
134        self.pruning_threshold = threshold;
135        self
136    }
137
138    pub fn with_reranking(mut self, factor: usize) -> Self {
139        self.rerank_factor = Some(factor);
140        self
141    }
142
143    pub fn with_score_fn(
144        mut self,
145        f: impl Fn(f32, &RecordBatch) -> f32 + Send + Sync + 'static,
146    ) -> Self {
147        self.score_fn = Some(ScoreFn::new(f));
148        self
149    }
150
151    pub fn with_hybrid(mut self, cfg: crate::bm25::HybridConfig) -> Self {
152        self.hybrid = Some(cfg);
153        self
154    }
155}
156
157#[derive(Debug)]
158pub struct SearchResult {
159    pub row_id: RowId,
160    pub distance: f32,
161    pub file_path: String,
162}
163
164/// Search across all files in the latest snapshot, with geometric pruning.
165///
166/// Flow:
167/// 1. Load file list from catalog (includes centroid metadata)
168/// 2. Prune files whose centroid + radius cannot contain a result within `pruning_threshold`
169/// 3. For surviving files: load bytes, deserialize HNSW, run top-k search
170/// 4. Global merge of all per-file top-k lists, return global top-k
171pub async fn search(
172    table: &TableIdent,
173    query: &[f32],
174    config: SearchConfig,
175    vector_column: &str,
176    dim: u32,
177    catalog: Arc<dyn CatalogProvider>,
178    store: Arc<dyn Store>,
179) -> AilakeResult<Vec<SearchResult>> {
180    // Get file metadata (includes centroid info) without reading any data files
181    let all_files = catalog.list_files(table, None).await?;
182
183    // Determine vector metric from table metadata for correct distance computation
184    let table_meta = catalog.load_table(table).await?;
185
186    // Validate query dim against the column's stored dim.
187    // Primary column: use `ailake.vector-dim`. Secondary columns: use `ailake.dim-<col>`.
188    // Skip validation when the column has no stored dim (e.g. old tables written before
189    // multi-column support).
190    let primary_col = table_meta
191        .properties
192        .get("ailake.vector-column")
193        .map(String::as_str)
194        .unwrap_or("");
195    let stored_dim_key = if vector_column == primary_col {
196        "ailake.vector-dim".to_string()
197    } else {
198        format!("ailake.dim-{vector_column}")
199    };
200    if let Some(table_dim_str) = table_meta.properties.get(&stored_dim_key) {
201        if let Ok(table_dim) = table_dim_str.parse::<u32>() {
202            let query_dim = query.len() as u32;
203            if query_dim != table_dim {
204                let table_model = table_meta
205                    .properties
206                    .get(EmbeddingModelInfo::property_key())
207                    .cloned()
208                    .unwrap_or_else(|| format!("dim={}", table_dim));
209                return Err(AilakeError::ModelMismatch {
210                    table_model,
211                    table_dim,
212                    batch_model: format!("query dim={}", query_dim),
213                    batch_dim: query_dim,
214                });
215            }
216        }
217    }
218
219    // Metric: prefer per-column `ailake.metric-<col>`, fall back to primary metric.
220    let metric_key = if vector_column == primary_col {
221        "ailake.vector-metric".to_string()
222    } else {
223        format!("ailake.metric-{vector_column}")
224    };
225    let metric = parse_metric(
226        table_meta
227            .properties
228            .get(&metric_key)
229            .or_else(|| table_meta.properties.get("ailake.vector-metric"))
230            .map(String::as_str)
231            .unwrap_or("cosine"),
232    );
233
234    // Partition pruning: skip files not belonging to the requested partition value.
235    let all_files = if let Some(ref pv) = config.partition_filter {
236        let before = all_files.len();
237        let filtered: Vec<_> = all_files
238            .into_iter()
239            .filter(|f| f.partition_value.as_deref() == Some(pv.as_str()))
240            .collect();
241        debug!(
242            "ailake: partition pruning '{}' — {}/{} files survive",
243            pv,
244            filtered.len(),
245            before
246        );
247        filtered
248    } else {
249        all_files
250    };
251
252    // Geometric pruning: skip files whose centroid is too far from the query
253    let total_files = all_files.len();
254    let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
255    debug!(
256        "ailake: geometric pruning — {}/{} files survive (threshold={})",
257        surviving_files.len(),
258        total_files,
259        config.pruning_threshold
260    );
261
262    // Phase F — Bloom pruning: for hybrid queries, load per-file Bloom filters from
263    // the Puffin stats file and skip files where no query term can be present.
264    let surviving_files = if let Some(ref h) = config.hybrid {
265        let bloom_map = load_bloom_map(&table_meta, store.as_ref()).await;
266        if !bloom_map.is_empty() {
267            BloomPruner::prune(surviving_files, &h.query_text, &bloom_map)
268        } else {
269            surviving_files
270        }
271    } else {
272        surviving_files
273    };
274
275    // Phase H: load equality delete filter for this snapshot.
276    // Reads delete manifests from the catalog and downloads each equality delete Avro file.
277    // Empty filter is a no-op. On error: warn and continue with empty filter (data visible).
278    let eq_del_filter = match catalog.list_equality_deletes(table, None).await {
279        Ok(edfs) if !edfs.is_empty() => {
280            match EqualityDeleteFilter::from_files(&store, &edfs).await {
281                Ok(f) => f,
282                Err(e) => {
283                    warn!("ailake: equality delete filter build failed: {e} — rows may appear");
284                    EqualityDeleteFilter::empty()
285                }
286            }
287        }
288        _ => EqualityDeleteFilter::empty(),
289    };
290
291    // Compute candidate pool: hybrid needs a larger pool for BM25 re-ranking.
292    let candidate_k = match (&config.hybrid, config.rerank_factor) {
293        (Some(h), rf) => {
294            let pool = h.candidate_pool.unwrap_or(config.top_k * 10);
295            pool.max(rf.map_or(config.top_k, |f| f * config.top_k))
296        }
297        (None, Some(factor)) => config.top_k * factor,
298        (None, None) => config.top_k,
299    };
300
301    let use_hybrid = config.hybrid.is_some();
302
303    // Load BM25 stats from the table's stats file when hybrid search is active.
304    let bm25_stats: Option<crate::bm25::IdfStats> = if let Some(ref h) = config.hybrid {
305        if h.text_columns.is_empty() {
306            None
307        } else {
308            let stats_path = table_meta
309                .properties
310                .get(crate::bm25::BM25_STATS_PATH_PROP)
311                .map(String::as_str)
312                .unwrap_or(crate::bm25::BM25_STATS_FILE);
313            match store.get(stats_path).await {
314                Ok(bytes) => crate::bm25::IdfStats::from_bytes(&bytes).ok(),
315                Err(_) => {
316                    debug!(
317                        "ailake: BM25 stats not found at '{}' — falling back to empty corpus IDF",
318                        stats_path
319                    );
320                    None
321                }
322            }
323        }
324    } else {
325        None
326    };
327
328    // raw_candidates: (row_id, vec_dist, file_path, bm25_text) for hybrid re-ranking.
329    // Only populated when use_hybrid = true; otherwise all_results is populated directly.
330    let mut raw_candidates: Vec<(RowId, f32, String, String)> = Vec::new();
331    let mut all_results: Vec<SearchResult> = Vec::new();
332
333    for file_entry in &surviving_files {
334        let file_bytes: Bytes = store.get(&file_entry.path).await?;
335        let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
336
337        // V3 Deletion Vector: fetch bitmap once per file (range GET from Puffin .dvd).
338        // None for V2 tables or V3 files with no deletes. On fetch error: warn + continue
339        // without mask (surfacing deleted rows is safer than hard-failing the search).
340        let dv_bitmap: Option<roaring::RoaringBitmap> =
341            if let Some(ref dv) = file_entry.deletion_vector {
342                match crate::dv::load_deletion_vector(&store, dv).await {
343                    Ok(bm) => {
344                        debug!(
345                            "ailake: DV loaded ({} deletions) for {}",
346                            bm.len(),
347                            file_entry.path
348                        );
349                        Some(bm)
350                    }
351                    Err(e) => {
352                        warn!(
353                            "ailake: DV fetch failed for '{}': {e} — deleted rows may appear",
354                            file_entry.path
355                        );
356                        None
357                    }
358                }
359            } else {
360                None
361            };
362
363        // Parquet read required for: flat scan fallback, exact reranking, score_fn, hybrid,
364        // or when equality delete filter must check column values per-row.
365        let need_parquet = file_entry.index_status == IndexStatus::Indexing
366            || !reader.is_ailake_file()
367            || config.rerank_factor.is_some()
368            || config.score_fn.is_some()
369            || use_hybrid
370            || !eq_del_filter.is_empty();
371
372        if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
373            debug!(
374                "ailake: flat scan fallback for {} (index_status={:?})",
375                file_entry.path, file_entry.index_status
376            );
377            let (raw_batch, raw_vectors) = reader.read_parquet()?;
378            // Phase G: inject columns added via schema evolution with initial_default values.
379            let batch = SchemaFiller::fill(raw_batch, &table_meta.schema_fields)?;
380            for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
381                // Skip rows marked as deleted by a V3 Deletion Vector.
382                if dv_bitmap
383                    .as_ref()
384                    .is_some_and(|bm| bm.contains(row_id.as_u64() as u32))
385                {
386                    continue;
387                }
388                // Phase H: skip rows matched by an equality delete predicate.
389                if eq_del_filter.should_delete_row(&batch, row_id.as_u64() as usize) {
390                    continue;
391                }
392                if use_hybrid {
393                    let text = extract_text_for_row(
394                        &batch,
395                        row_id.as_u64() as usize,
396                        config.hybrid.as_ref().unwrap(),
397                    );
398                    raw_candidates.push((row_id, distance, file_entry.path.clone(), text));
399                } else {
400                    let final_score = apply_score_fn(&config.score_fn, distance, row_id, &batch);
401                    all_results.push(SearchResult {
402                        row_id,
403                        distance: final_score,
404                        file_path: file_entry.path.clone(),
405                    });
406                }
407            }
408            continue;
409        }
410
411        let index = reader.load_any_index_for_column(vector_column)?;
412        let local_results = index.search(query, candidate_k, config.ef_search);
413
414        let parquet_data = if need_parquet {
415            let (raw_batch, raw_vecs) = reader.read_parquet()?;
416            // Phase G: fill missing columns for old files before score_fn / hybrid BM25.
417            let filled = SchemaFiller::fill(raw_batch, &table_meta.schema_fields)?;
418            Some((filled, raw_vecs))
419        } else {
420            None
421        };
422
423        for (row_id, approx_dist) in local_results {
424            // Skip rows marked as deleted by a V3 Deletion Vector.
425            if dv_bitmap
426                .as_ref()
427                .is_some_and(|bm| bm.contains(row_id.as_u64() as u32))
428            {
429                continue;
430            }
431            let idx = row_id.as_u64() as usize;
432            // Phase H: skip rows matched by an equality delete predicate.
433            // parquet_data is always loaded when eq_del_filter is non-empty (see need_parquet).
434            if let Some((ref batch, _)) = parquet_data {
435                if eq_del_filter.should_delete_row(batch, idx) {
436                    continue;
437                }
438            }
439
440            let distance = if config.rerank_factor.is_some() {
441                match parquet_data.as_ref().and_then(|(_, vecs)| vecs.get(idx)) {
442                    Some(v) => exact_distance(metric, query, v),
443                    None => {
444                        error!(
445                            "ailake: invariant violated — row_id {} out of bounds \
446                             (file={}); Parquet and HNSW node count out of sync; \
447                             run compaction to rebuild",
448                            idx, file_entry.path
449                        );
450                        f32::INFINITY
451                    }
452                }
453            } else {
454                approx_dist
455            };
456
457            if use_hybrid {
458                let text = parquet_data.as_ref().map_or(String::new(), |(batch, _)| {
459                    extract_text_for_row(batch, idx, config.hybrid.as_ref().unwrap())
460                });
461                raw_candidates.push((row_id, distance, file_entry.path.clone(), text));
462            } else {
463                let final_score = if let Some((ref batch, _)) = parquet_data {
464                    apply_score_fn(&config.score_fn, distance, row_id, batch)
465                } else {
466                    distance
467                };
468                all_results.push(SearchResult {
469                    row_id,
470                    distance: final_score,
471                    file_path: file_entry.path.clone(),
472                });
473            }
474        }
475    }
476
477    // Hybrid BM25 fusion: applied after all HNSW candidates are collected.
478    if let Some(ref h) = config.hybrid {
479        let empty_stats = crate::bm25::IdfStats::default();
480        let stats = bm25_stats.as_ref().unwrap_or(&empty_stats);
481        let scorer = crate::bm25::BM25Scorer::new(stats);
482
483        // Compute BM25 score for each candidate.
484        let bm25_scores: Vec<f32> = raw_candidates
485            .iter()
486            .map(|(_, _, _, text)| scorer.score(&h.query_text, text))
487            .collect();
488
489        // Rank by vector distance (already sorted within each file, but merge globally).
490        raw_candidates.sort_by(|a, b| a.1.total_cmp(&b.1));
491        let vec_ranks: Vec<usize> = (0..raw_candidates.len()).collect();
492
493        // Rank by BM25 score descending (higher BM25 = better).
494        let mut bm25_indexed: Vec<(usize, f32)> = bm25_scores.iter().copied().enumerate().collect();
495        bm25_indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
496        let mut bm25_rank_of = vec![0usize; raw_candidates.len()];
497        for (rank, (orig_idx, _)) in bm25_indexed.iter().enumerate() {
498            bm25_rank_of[*orig_idx] = rank;
499        }
500
501        use crate::bm25::{linear_score, rrf_score, HybridFusion};
502
503        let fused: Vec<f32> = match h.fusion {
504            HybridFusion::Rrf => vec_ranks
505                .iter()
506                .enumerate()
507                .map(|(i, &vr)| rrf_score(vr, bm25_rank_of[i], h.bm25_weight))
508                .collect(),
509            HybridFusion::Linear => {
510                let min_d = raw_candidates
511                    .iter()
512                    .map(|r| r.1)
513                    .fold(f32::INFINITY, f32::min);
514                let max_d = raw_candidates
515                    .iter()
516                    .map(|r| r.1)
517                    .fold(f32::NEG_INFINITY, f32::max);
518                let min_b = bm25_scores.iter().copied().fold(f32::INFINITY, f32::min);
519                let max_b = bm25_scores
520                    .iter()
521                    .copied()
522                    .fold(f32::NEG_INFINITY, f32::max);
523                raw_candidates
524                    .iter()
525                    .enumerate()
526                    .map(|(i, r)| {
527                        linear_score(
528                            r.1,
529                            min_d,
530                            max_d,
531                            bm25_scores[i],
532                            min_b,
533                            max_b,
534                            h.bm25_weight,
535                        )
536                    })
537                    .collect()
538            }
539        };
540
541        for (i, (row_id, _, file_path, _)) in raw_candidates.into_iter().enumerate() {
542            all_results.push(SearchResult {
543                row_id,
544                distance: fused[i],
545                file_path,
546            });
547        }
548
549        // For RRF: lower (more negative) = better; for Linear: lower = better. Same convention.
550        all_results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
551    } else {
552        all_results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
553    }
554
555    all_results.truncate(config.top_k);
556    Ok(all_results)
557}
558
559/// Extract concatenated text from specified columns for a single row.
560fn extract_text_for_row(
561    batch: &RecordBatch,
562    row_idx: usize,
563    hybrid: &crate::bm25::HybridConfig,
564) -> String {
565    use arrow_array::cast::AsArray;
566    hybrid
567        .text_columns
568        .iter()
569        .filter_map(|col| {
570            batch.column_by_name(col).and_then(|arr| {
571                arr.as_string_opt::<i32>().and_then(|sa| {
572                    if row_idx < sa.len() && sa.is_valid(row_idx) {
573                        Some(sa.value(row_idx).to_string())
574                    } else {
575                        None
576                    }
577                })
578            })
579        })
580        .collect::<Vec<_>>()
581        .join(" ")
582}
583
/// A single (column, query vector) pair taking part in a cross-modal search.
#[derive(Debug, Clone)]
pub struct ModalQuery<'a> {
    /// Name of the vector column this arm searches; it must exist in the table.
    pub column: &'a str,
    /// The query vector for this arm's modality.
    pub query: &'a [f32],
    /// Multiplier on this arm's contribution to the RRF sum, i.e. the numerator
    /// of `weight / (k + rank)`. Use `1.0` on every arm for equal weighting.
    pub weight: f32,
    /// Vector dimensionality for `column`. Pass `0` to have it looked up in the
    /// table metadata (`ailake.vector-dim` for the primary column,
    /// `ailake.dim-<column>` for secondary columns).
    pub dim: u32,
}
598
/// Strategy used by `search_multimodal` to merge the per-column ranked lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: `score(d) = Σ weight_i / (k + rank_i(d))` with
    /// the standard `k = 60`. The fused `SearchResult.distance` is set to
    /// `-rrf_score`, so an ascending sort on distance yields the RRF order.
    Rrf,
}
607
608/// Cross-modal search: run independent HNSW searches across N vector columns,
609/// then fuse per-column ranked lists using Reciprocal Rank Fusion.
610///
611/// Each `ModalQuery` specifies a column name, its query vector, RRF weight, and dim.
612/// When `ModalQuery.dim == 0`, the dim is auto-detected from `ailake.dim-<col>` /
613/// `ailake.vector-dim` in table metadata.
614/// Results are de-duplicated by `(file_path, row_id)` and ranked by aggregate
615/// RRF score. `SearchResult.distance` stores `-rrf_score` (lower = better) so
616/// existing sort-ascending callers get the correct ordering.
617pub async fn search_multimodal(
618    table: &TableIdent,
619    queries: &[ModalQuery<'_>],
620    config: SearchConfig,
621    catalog: Arc<dyn CatalogProvider>,
622    store: Arc<dyn Store>,
623    fusion: FusionMethod,
624) -> AilakeResult<Vec<SearchResult>> {
625    use std::collections::HashMap;
626
627    if queries.is_empty() {
628        return Err(AilakeError::InvalidArgument(
629            "search_multimodal requires at least one ModalQuery".into(),
630        ));
631    }
632
633    // Load table metadata once for dim auto-detection and metric resolution.
634    let table_meta = catalog.load_table(table).await?;
635    let primary_col = table_meta
636        .properties
637        .get("ailake.vector-column")
638        .cloned()
639        .unwrap_or_default();
640    let primary_dim: u32 = table_meta
641        .properties
642        .get("ailake.vector-dim")
643        .and_then(|s| s.parse().ok())
644        .unwrap_or(0);
645
646    // Fetch more candidates per column so RRF has enough to fuse.
647    let per_col_k = (config.top_k * queries.len().max(2)).min(1000);
648
649    let mut per_col_results: Vec<(f32, Vec<SearchResult>)> = Vec::with_capacity(queries.len());
650    for mq in queries {
651        // Resolve dim: caller-supplied > per-column property > primary column dim.
652        let resolved_dim = if mq.dim > 0 {
653            mq.dim
654        } else if mq.column == primary_col {
655            primary_dim
656        } else {
657            table_meta
658                .properties
659                .get(&format!("ailake.dim-{}", mq.column))
660                .and_then(|s| s.parse().ok())
661                .unwrap_or(mq.query.len() as u32)
662        };
663
664        let col_config = SearchConfig {
665            top_k: per_col_k,
666            ef_search: config.ef_search,
667            pruning_threshold: config.pruning_threshold,
668            rerank_factor: config.rerank_factor,
669            score_fn: None,
670            partition_filter: config.partition_filter.clone(),
671            hybrid: None,
672        };
673        let results = search(
674            table,
675            mq.query,
676            col_config,
677            mq.column,
678            resolved_dim,
679            catalog.clone(),
680            store.clone(),
681        )
682        .await?;
683        per_col_results.push((mq.weight, results));
684    }
685
686    // RRF fusion: accumulate score per (file_path, row_id).
687    const K: f32 = 60.0;
688    let mut scores: HashMap<(String, u64), f32> = HashMap::new();
689
690    for (weight, results) in &per_col_results {
691        for (rank, r) in results.iter().enumerate() {
692            let key = (r.file_path.clone(), r.row_id.as_u64());
693            let rrf = weight / (K + rank as f32 + 1.0);
694            *scores.entry(key).or_insert(0.0) += rrf;
695        }
696    }
697
698    // Build SearchResult list sorted by descending RRF score.
699    // Store `-rrf_score` as `.distance` so callers sorting ascending get correct order.
700    let all_files = catalog.list_files(table, None).await?;
701    let _ = all_files; // centroid not needed for fusion — just need file_path+row_id
702
703    // Collect unique candidates: prefer the row's appearance in the first column's results.
704    let mut seen: HashMap<(String, u64), f32> = HashMap::new();
705    for (_, results) in &per_col_results {
706        for r in results {
707            let key = (r.file_path.clone(), r.row_id.as_u64());
708            let rrf_score = *scores.get(&key).unwrap_or(&0.0);
709            seen.entry(key).or_insert(rrf_score);
710        }
711    }
712
713    let mut fused: Vec<SearchResult> = seen
714        .into_iter()
715        .map(|((file_path, row_id_u64), rrf_score)| SearchResult {
716            row_id: RowId::new(row_id_u64),
717            distance: -rrf_score,
718            file_path,
719        })
720        .collect();
721
722    fused.sort_by(|a, b| {
723        a.distance
724            .partial_cmp(&b.distance)
725            .unwrap_or(std::cmp::Ordering::Equal)
726    });
727    fused.truncate(config.top_k);
728
729    let _ = fusion; // only RRF implemented; enum is extensible
730
731    Ok(fused)
732}
733
734/// Apply `score_fn` to a single result row, or return `distance` unchanged.
735///
736/// Slices the batch to a 1-row RecordBatch at `row_id` and calls the fn.
737/// If `score_fn` is `None` or the row index is out of bounds, returns `distance`.
738#[inline]
739fn apply_score_fn(
740    score_fn: &Option<ScoreFn>,
741    distance: f32,
742    row_id: RowId,
743    batch: &RecordBatch,
744) -> f32 {
745    match score_fn {
746        None => distance,
747        Some(f) => {
748            let idx = row_id.as_u64() as usize;
749            if idx < batch.num_rows() {
750                f.call(distance, &batch.slice(idx, 1))
751            } else {
752                distance
753            }
754        }
755    }
756}
757
758/// Brute-force top-k search over raw vectors. Used for Indexing shards.
759fn flat_search(
760    raw: &[Vec<f32>],
761    query: &[f32],
762    top_k: usize,
763    metric: VectorMetric,
764) -> Vec<(RowId, f32)> {
765    let mut results: Vec<(RowId, f32)> = raw
766        .iter()
767        .enumerate()
768        .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
769        .collect();
770    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
771    results.truncate(top_k);
772    results
773}
774
775fn parse_metric(s: &str) -> VectorMetric {
776    match s {
777        "euclidean" => VectorMetric::Euclidean,
778        "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
779        _ => VectorMetric::Cosine,
780    }
781}
782
783/// Pre-loaded search session: all HNSW indexes loaded into memory once.
784///
785/// Useful for benchmarks and servers that issue many queries against the same
786/// snapshot. Avoids re-loading and re-deserializing indexes on every call.
787pub struct SearchSession {
788    shards: Vec<LoadedShard>,
789    metric: VectorMetric,
790}
791
792struct LoadedShard {
793    entry: DataFileEntry,
794    /// None when the shard is still being indexed (IndexStatus::Indexing).
795    index: Option<AnyIndex>,
796    /// Raw F32 vectors: always present for Indexing shards (flat scan), optionally
797    /// present for Ready shards when `load_raw = true` (reranking).
798    raw_vectors: Option<Vec<Vec<f32>>>,
799}
800
801impl SearchSession {
802    /// Load all indexes for the latest snapshot into memory.
803    ///
804    /// Pass `load_raw = true` when reranking will be used (`rerank_factor` is
805    /// `Some`); it reads the full parquet columns so exact distances are
806    /// available without extra I/O during `search_query`.
807    pub async fn load(
808        table: &TableIdent,
809        vector_column: &str,
810        dim: u32,
811        catalog: Arc<dyn CatalogProvider>,
812        store: Arc<dyn Store>,
813        load_raw: bool,
814    ) -> AilakeResult<Self> {
815        let all_files = catalog.list_files(table, None).await?;
816        let table_meta = catalog.load_table(table).await?;
817        let metric = parse_metric(
818            table_meta
819                .properties
820                .get("ailake.vector-metric")
821                .map(String::as_str)
822                .unwrap_or("cosine"),
823        );
824
825        let mut shards = Vec::with_capacity(all_files.len());
826        for entry in all_files {
827            let file_bytes: Bytes = store.get(&entry.path).await?;
828            let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
829
830            if entry.index_status == IndexStatus::Indexing {
831                // HNSW not yet built — load raw vectors for flat scan.
832                let (_, raw_vecs) = reader.read_parquet()?;
833                shards.push(LoadedShard {
834                    entry,
835                    index: None,
836                    raw_vectors: Some(raw_vecs),
837                });
838            } else if reader.is_ailake_file() {
839                let mut index = reader.load_any_index_for_column(vector_column)?;
840                let raw_vectors = if load_raw {
841                    index.quantize_to_f16();
842                    let (_, vecs) = reader.read_parquet()?;
843                    Some(vecs)
844                } else {
845                    None
846                };
847                shards.push(LoadedShard {
848                    entry,
849                    index: Some(index),
850                    raw_vectors,
851                });
852            }
853        }
854
855        Ok(Self { shards, metric })
856    }
857
858    /// Number of loaded shards.
859    pub fn shard_count(&self) -> usize {
860        self.shards.len()
861    }
862
863    /// Search multiple queries in one call.
864    ///
865    /// For shards with raw vectors (Indexing or reranking): dispatches to GPU batch
866    /// matmul when a CUDA device is available, falling back to CPU flat scan.
867    /// For indexed shards (HNSW / IVF-PQ): rayon parallel-map over queries — graph
868    /// traversal is inherently sequential and has no GPU batch path.
869    ///
870    /// Returns one `Vec<SearchResult>` per input query, in the same order.
871    pub fn search_batch(
872        &self,
873        queries: &[Vec<f32>],
874        config: &SearchConfig,
875    ) -> Vec<Vec<SearchResult>> {
876        if queries.is_empty() {
877            return vec![];
878        }
879
880        let n_queries = queries.len();
881        let candidate_k = match config.rerank_factor {
882            Some(factor) => config.top_k * factor,
883            None => config.top_k,
884        };
885        let use_nvidia = ailake_index::hardware::detect_cuda();
886        let use_amd = ailake_index::hardware::detect_rocm();
887
888        // Accumulate per-query results across all shards.
889        let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
890
891        for shard in &self.shards {
892            if let Some(raw) = &shard.raw_vectors {
893                // Flat-scan shard — try GPU batch path (NVIDIA first, then AMD ROCm).
894                if !raw.is_empty() {
895                    let dim = raw[0].len();
896                    let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
897                    let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
898                    let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
899
900                    let gpu_batch = if use_nvidia {
901                        ailake_index::gpu::try_nvidia_search_batch(
902                            &q_refs,
903                            &row_ids,
904                            &flat,
905                            dim,
906                            self.metric,
907                            candidate_k,
908                        )
909                    } else if use_amd {
910                        ailake_index::gpu::try_rocm_search_batch(
911                            &q_refs,
912                            &row_ids,
913                            &flat,
914                            dim,
915                            self.metric,
916                            candidate_k,
917                        )
918                    } else {
919                        None
920                    };
921
922                    if let Some(batch) = gpu_batch {
923                        for (qi, results) in batch.into_iter().enumerate() {
924                            for (row_id, distance) in results {
925                                all_results[qi].push(SearchResult {
926                                    row_id,
927                                    distance,
928                                    file_path: shard.entry.path.clone(),
929                                });
930                            }
931                        }
932                        continue;
933                    }
934                }
935
936                // CPU fallback for flat scan.
937                for (qi, query) in queries.iter().enumerate() {
938                    for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
939                        all_results[qi].push(SearchResult {
940                            row_id,
941                            distance,
942                            file_path: shard.entry.path.clone(),
943                        });
944                    }
945                }
946            } else if let Some(index) = &shard.index {
947                // Indexed shard — rayon parallel-map over queries.
948                let shard_results: Vec<Vec<SearchResult>> = queries
949                    .par_iter()
950                    .map(|query| {
951                        index
952                            .search(query, candidate_k, config.ef_search)
953                            .into_iter()
954                            .map(|(row_id, distance)| SearchResult {
955                                row_id,
956                                distance,
957                                file_path: shard.entry.path.clone(),
958                            })
959                            .collect()
960                    })
961                    .collect();
962
963                for (qi, results) in shard_results.into_iter().enumerate() {
964                    all_results[qi].extend(results);
965                }
966            }
967        }
968
969        // Sort + truncate per query.
970        for results in &mut all_results {
971            results.sort_by(|a, b| {
972                a.distance
973                    .partial_cmp(&b.distance)
974                    .unwrap_or(std::cmp::Ordering::Equal)
975            });
976            results.truncate(config.top_k);
977        }
978
979        all_results
980    }
981
982    /// Search using pre-loaded indexes. No I/O — pure in-memory search.
983    pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
984        let candidate_k = match config.rerank_factor {
985            Some(factor) => config.top_k * factor,
986            None => config.top_k,
987        };
988
989        let mut all_results: Vec<SearchResult> = self
990            .shards
991            .par_iter()
992            .flat_map(|shard| {
993                // Geometric pruning per shard.
994                if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
995                    let dist = match self.metric {
996                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
997                            ailake_vec::cosine_distance(query, &centroid.values)
998                        }
999                        VectorMetric::Euclidean => {
1000                            ailake_vec::euclidean_distance(query, &centroid.values)
1001                        }
1002                        VectorMetric::DotProduct => {
1003                            -ailake_vec::dot_product(query, &centroid.values)
1004                        }
1005                    };
1006                    if dist - centroid.radius > config.pruning_threshold {
1007                        return vec![];
1008                    }
1009                }
1010
1011                if let Some(index) = &shard.index {
1012                    // Ready shard: HNSW or IVF-PQ search (dispatched by AnyIndex).
1013                    let local_results = index.search(query, candidate_k, config.ef_search);
1014                    if config.rerank_factor.is_some() {
1015                        if let Some(raw) = &shard.raw_vectors {
1016                            local_results
1017                                .into_iter()
1018                                .map(|(row_id, _approx_dist)| {
1019                                    let idx = row_id.as_u64() as usize;
1020                                    let exact_dist = raw
1021                                        .get(idx)
1022                                        .map(|v| exact_distance(self.metric, query, v))
1023                                        .unwrap_or(f32::INFINITY);
1024                                    SearchResult {
1025                                        row_id,
1026                                        distance: exact_dist,
1027                                        file_path: shard.entry.path.clone(),
1028                                    }
1029                                })
1030                                .collect()
1031                        } else {
1032                            local_results
1033                                .into_iter()
1034                                .map(|(row_id, distance)| SearchResult {
1035                                    row_id,
1036                                    distance,
1037                                    file_path: shard.entry.path.clone(),
1038                                })
1039                                .collect()
1040                        }
1041                    } else {
1042                        local_results
1043                            .into_iter()
1044                            .map(|(row_id, distance)| SearchResult {
1045                                row_id,
1046                                distance,
1047                                file_path: shard.entry.path.clone(),
1048                            })
1049                            .collect()
1050                    }
1051                } else if let Some(raw) = &shard.raw_vectors {
1052                    // Indexing shard: exact flat scan.
1053                    flat_search(raw, query, candidate_k, self.metric)
1054                        .into_iter()
1055                        .map(|(row_id, distance)| SearchResult {
1056                            row_id,
1057                            distance,
1058                            file_path: shard.entry.path.clone(),
1059                        })
1060                        .collect()
1061                } else {
1062                    vec![]
1063                }
1064            })
1065            .collect();
1066
1067        all_results.sort_by(|a, b| {
1068            a.distance
1069                .partial_cmp(&b.distance)
1070                .unwrap_or(std::cmp::Ordering::Equal)
1071        });
1072        all_results.truncate(config.top_k);
1073        all_results
1074    }
1075}
1076
1077/// Pure BM25 full-text search across all Parquet files in the table.
1078///
1079/// Scans every surviving file (O(N) complexity), scores each row with BM25 against
1080/// `query_text`, and returns the global top-k by score. IDF stats are loaded from
1081/// `metadata/ailake_bm25_stats.bin` (written by `TableWriter` when `bm25_text_column`
1082/// is configured). If the stats file is absent, IDF defaults to an empty corpus
1083/// (all terms treated as maximally rare — directionally correct but less precise).
1084///
1085/// For pure-lexical search at scale (millions of rows, hundreds of files), consider
1086/// using SQL `LIKE` / `ILIKE` via DuckDB/Trino over the Iceberg-compatible table.
1087/// This function is best suited for small-medium tables or as a lexical complement
1088/// to `search()` for tables where the document count per file is manageable.
1089pub async fn search_text(
1090    table: &TableIdent,
1091    query_text: &str,
1092    text_columns: &[&str],
1093    top_k: usize,
1094    catalog: Arc<dyn CatalogProvider>,
1095    store: Arc<dyn Store>,
1096    partition_filter: Option<&str>,
1097) -> AilakeResult<Vec<SearchResult>> {
1098    use arrow_array::cast::AsArray;
1099
1100    if text_columns.is_empty() {
1101        return Err(AilakeError::InvalidArgument(
1102            "search_text requires at least one text column".into(),
1103        ));
1104    }
1105
1106    let all_files = catalog.list_files(table, None).await?;
1107    let table_meta = catalog.load_table(table).await?;
1108
1109    // Partition pruning
1110    let files: Vec<_> = if let Some(pv) = partition_filter {
1111        all_files
1112            .into_iter()
1113            .filter(|f| f.partition_value.as_deref() == Some(pv))
1114            .collect()
1115    } else {
1116        all_files
1117    };
1118
1119    // Load BM25 stats
1120    let stats_path = table_meta
1121        .properties
1122        .get(crate::bm25::BM25_STATS_PATH_PROP)
1123        .map(String::as_str)
1124        .unwrap_or(crate::bm25::BM25_STATS_FILE);
1125    let stats = match store.get(stats_path).await {
1126        Ok(bytes) => crate::bm25::IdfStats::from_bytes(&bytes).unwrap_or_default(),
1127        Err(_) => {
1128            debug!(
1129                "ailake: BM25 stats not found at '{}' — using empty corpus IDF",
1130                stats_path
1131            );
1132            crate::bm25::IdfStats::default()
1133        }
1134    };
1135    let scorer = crate::bm25::BM25Scorer::new(&stats);
1136
1137    // Phase H: equality delete filter for search_text results.
1138    let eq_del_filter = match catalog.list_equality_deletes(table, None).await {
1139        Ok(edfs) if !edfs.is_empty() => {
1140            match EqualityDeleteFilter::from_files(&store, &edfs).await {
1141                Ok(f) => f,
1142                Err(e) => {
1143                    warn!("ailake: equality delete filter build failed in search_text: {e}");
1144                    EqualityDeleteFilter::empty()
1145                }
1146            }
1147        }
1148        _ => EqualityDeleteFilter::empty(),
1149    };
1150
1151    let mut results: Vec<SearchResult> = Vec::new();
1152
1153    for file_entry in &files {
1154        let file_bytes = store.get(&file_entry.path).await?;
1155        // Use dim=0 — we only read the Parquet columns, not the HNSW.
1156        let reader = AilakeFileReader::new(file_bytes, "", 0);
1157        let (raw_batch, _) = reader.read_parquet()?;
1158        // Phase G: fill missing columns for old files before BM25 text extraction.
1159        let batch = SchemaFiller::fill(raw_batch, &table_meta.schema_fields)?;
1160
1161        for row_idx in 0..batch.num_rows() {
1162            // Phase H: skip rows matched by equality delete predicate.
1163            if eq_del_filter.should_delete_row(&batch, row_idx) {
1164                continue;
1165            }
1166            let doc_text: String = text_columns
1167                .iter()
1168                .filter_map(|&col| {
1169                    batch.column_by_name(col).and_then(|arr| {
1170                        arr.as_string_opt::<i32>().and_then(|sa| {
1171                            if sa.is_valid(row_idx) {
1172                                Some(sa.value(row_idx).to_string())
1173                            } else {
1174                                None
1175                            }
1176                        })
1177                    })
1178                })
1179                .collect::<Vec<_>>()
1180                .join(" ");
1181
1182            if doc_text.is_empty() {
1183                continue;
1184            }
1185
1186            let bm25 = scorer.score(query_text, &doc_text);
1187            if bm25 > 0.0 {
1188                // Negate so that sort-ascending = best-first (lower distance = higher BM25).
1189                results.push(SearchResult {
1190                    row_id: RowId::new(row_idx as u64),
1191                    distance: -bm25,
1192                    file_path: file_entry.path.clone(),
1193                });
1194            }
1195        }
1196    }
1197
1198    results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
1199    results.truncate(top_k);
1200    Ok(results)
1201}
1202
1203/// Fetch full row data for a slice of search results.
1204///
1205/// Groups results by Parquet file, reads each file once, extracts the matching rows
1206/// via `arrow_select::take`, then concatenates everything back in original top-k order
1207/// with a `_distance: Float32` column appended.
1208///
1209/// Use this immediately after `search()` to retrieve the actual text / metadata
1210/// columns (e.g. `chunk_text`, `document_title`) alongside the distance scores.
1211pub async fn fetch_rows(
1212    results: &[SearchResult],
1213    store: Arc<dyn Store>,
1214    vector_column: &str,
1215    dim: u32,
1216) -> AilakeResult<RecordBatch> {
1217    use std::collections::HashMap;
1218
1219    use arrow_array::{ArrayRef, Float32Array, UInt32Array};
1220    use arrow_schema::{DataType, Field, Schema};
1221    use arrow_select::{concat::concat_batches, take::take};
1222
1223    if results.is_empty() {
1224        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
1225    }
1226
1227    // Group by file path; preserve original position for re-sorting.
1228    let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
1229    for (i, r) in results.iter().enumerate() {
1230        by_file
1231            .entry(r.file_path.as_str())
1232            .or_default()
1233            .push((r.row_id.as_u64(), r.distance, i));
1234    }
1235
1236    use arrow_array::FixedSizeListArray;
1237
1238    // (original_index, distance, single-row RecordBatch, decoded F32 vector)
1239    let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
1240
1241    for (file_path, rows) in &by_file {
1242        let bytes = store.get(file_path).await?;
1243        let reader = AilakeFileReader::new(bytes, vector_column, dim);
1244        let (batch, vectors) = reader.read_parquet()?;
1245
1246        for &(row_id, distance, pos) in rows {
1247            let idx = row_id as usize;
1248            if idx >= batch.num_rows() {
1249                tracing::warn!(
1250                    "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
1251                    idx,
1252                    batch.num_rows(),
1253                    file_path
1254                );
1255                continue;
1256            }
1257
1258            let indices = UInt32Array::from(vec![idx as u32]);
1259            let row_cols: Vec<ArrayRef> = batch
1260                .columns()
1261                .iter()
1262                .map(|col| {
1263                    take(col.as_ref(), &indices, None)
1264                        .map_err(|e| AilakeError::Arrow(e.to_string()))
1265                })
1266                .collect::<AilakeResult<Vec<_>>>()?;
1267
1268            let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
1269                .map_err(|e| AilakeError::Arrow(e.to_string()))?;
1270
1271            // Capture decoded F32 vector for this row (empty vec if not available).
1272            let vec = vectors
1273                .get(idx)
1274                .cloned()
1275                .unwrap_or_else(|| vec![0.0f32; dim as usize]);
1276
1277            collected.push((pos, distance, row_batch, vec));
1278        }
1279    }
1280
1281    if collected.is_empty() {
1282        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
1283    }
1284
1285    // Restore original top-k order from the search results slice.
1286    collected.sort_by_key(|(pos, _, _, _)| *pos);
1287
1288    let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
1289    let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
1290    let base_schema = collected[0].2.schema();
1291
1292    let combined =
1293        concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
1294
1295    // Build FixedSizeList<Float32> column with decoded vectors (F32, not raw F16 bytes).
1296    let flat_vecs: Vec<f32> = collected
1297        .iter()
1298        .flat_map(|(_, _, _, v)| v.iter().copied())
1299        .collect();
1300    let item_field = Arc::new(Field::new("item", DataType::Float32, false));
1301    let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
1302    let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
1303    let vec_field = Arc::new(Field::new(
1304        vector_column,
1305        DataType::FixedSizeList(item_field, dim as i32),
1306        false,
1307    ));
1308
1309    // Schema: tabular cols, then decoded vector col, then _distance.
1310    let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
1311    fields.push(vec_field);
1312    fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
1313    let new_schema = Arc::new(Schema::new(fields));
1314
1315    let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
1316    columns.push(Arc::new(vec_col));
1317    columns.push(Arc::new(Float32Array::from(distances)));
1318
1319    RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
1320}
1321
1322/// Load per-file BM25 Bloom filters from the Puffin stats file for the current snapshot.
1323///
1324/// Returns a map of `file_path → BloomFilter`. Empty map = no stats file available
1325/// (V2 table, first write, or fetch failure). The scanner applies Bloom pruning only
1326/// when the map is non-empty.
1327async fn load_bloom_map(
1328    table_meta: &ailake_catalog::TableMetadata,
1329    store: &dyn Store,
1330) -> std::collections::HashMap<String, crate::bloom::BloomFilter> {
1331    let stats_path = match &table_meta.current_statistics_path {
1332        Some(p) => p.clone(),
1333        None => return std::collections::HashMap::new(),
1334    };
1335    let bytes = match store.get(&stats_path).await {
1336        Ok(b) => b,
1337        Err(e) => {
1338            debug!("ailake: Phase F — could not load Puffin stats ({stats_path}): {e}");
1339            return std::collections::HashMap::new();
1340        }
1341    };
1342    let reader = ailake_catalog::AilakePuffinReader::new(&bytes);
1343    let bloom_entries = match reader.read_bm25_blooms() {
1344        Ok(e) => e,
1345        Err(e) => {
1346            warn!("ailake: Phase F — Puffin bloom parse error: {e}");
1347            return std::collections::HashMap::new();
1348        }
1349    };
1350    bloom_entries
1351        .into_iter()
1352        .filter_map(|entry| {
1353            let bf = crate::bloom::BloomFilter::from_bytes(&entry.bloom_bytes)?;
1354            Some((entry.path, bf))
1355        })
1356        .collect()
1357}
1358
1359#[cfg(test)]
1360mod tests {
1361    use super::*;
1362    use crate::writer::MultiVectorBatch;
1363    use ailake_catalog::{HadoopCatalog, TableIdent};
1364    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
1365    use ailake_store::LocalStore;
1366    use arrow_array::{Int32Array, RecordBatch};
1367    use arrow_schema::{DataType, Field, Schema};
1368    use std::sync::Arc;
1369    use tempfile::TempDir;
1370
1371    fn make_policy(dim: u32) -> VectorStoragePolicy {
1372        VectorStoragePolicy {
1373            column_name: "embedding".to_string(),
1374            dim,
1375            metric: VectorMetric::Cosine,
1376            precision: VectorPrecision::F16,
1377            pq: None,
1378            keep_raw_for_reranking: true,
1379            pre_normalize: false,
1380            hnsw_m: None,
1381            hnsw_ef_construction: None,
1382            ivf_residual: false,
1383            embedding_model: None,
1384            modality: None,
1385            partition_by: None,
1386            partition_value: None,
1387            partition_column_type: None,
1388            partition_fields: vec![],
1389        }
1390    }
1391
1392    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
1393        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1394        let catalog = Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
1395        let table = TableIdent::new("default", "table");
1396
1397        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1398        let ids: Vec<i32> = (0..rows as i32).collect();
1399        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();
1400
1401        // Each row i has embedding with 1.0 at dimension i and 0 elsewhere (unit basis vectors)
1402        let embeddings: Vec<Vec<f32>> = (0..rows)
1403            .map(|i| {
1404                let mut v = vec![0.0f32; dim];
1405                v[i % dim] = 1.0;
1406                v
1407            })
1408            .collect();
1409
1410        let mut writer =
1411            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table, 2)
1412                .await
1413                .unwrap();
1414        writer.write_batch(&batch, &embeddings).await.unwrap();
1415        writer.commit().await.unwrap();
1416    }
1417
1418    #[tokio::test]
1419    async fn rerank_returns_correct_top_k_count() {
1420        let dir = TempDir::new().unwrap();
1421        let dim = 8usize;
1422        write_demo_table(&dir, dim, 8).await;
1423
1424        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1425        let catalog: Arc<dyn CatalogProvider> =
1426            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
1427        let table = TableIdent::new("default", "table");
1428
1429        let query = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
1430        let config = SearchConfig {
1431            top_k: 3,
1432            ef_search: 50,
1433            pruning_threshold: f32::INFINITY,
1434            rerank_factor: Some(2),
1435            score_fn: None,
1436            partition_filter: None,
1437            hybrid: None,
1438        };
1439
1440        let results = search(
1441            &table,
1442            &query,
1443            config,
1444            "embedding",
1445            dim as u32,
1446            catalog,
1447            store,
1448        )
1449        .await
1450        .unwrap();
1451
1452        assert_eq!(results.len(), 3);
1453    }
1454
1455    #[tokio::test]
1456    async fn rerank_nearest_is_exact_match() {
1457        let dir = TempDir::new().unwrap();
1458        let dim = 8usize;
1459        write_demo_table(&dir, dim, 8).await;
1460
1461        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1462        let catalog: Arc<dyn CatalogProvider> =
1463            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
1464        let table = TableIdent::new("default", "table");
1465
1466        // Row 0 has [1,0,0,...] — cosine distance to same query is 0
1467        let query = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
1468        let config = SearchConfig {
1469            top_k: 1,
1470            ef_search: 50,
1471            pruning_threshold: f32::INFINITY,
1472            rerank_factor: Some(4),
1473            score_fn: None,
1474            partition_filter: None,
1475            hybrid: None,
1476        };
1477
1478        let results = search(
1479            &table,
1480            &query,
1481            config,
1482            "embedding",
1483            dim as u32,
1484            catalog,
1485            store,
1486        )
1487        .await
1488        .unwrap();
1489
1490        assert_eq!(results.len(), 1);
1491        // Exact cosine distance between identical unit vectors is ~0 (F16 rounding allowed)
1492        assert!(
1493            results[0].distance < 1e-3,
1494            "distance was {}",
1495            results[0].distance
1496        );
1497        assert_eq!(results[0].row_id, RowId::new(0));
1498    }
1499
1500    #[tokio::test]
1501    async fn no_rerank_matches_default_behavior() {
1502        let dir = TempDir::new().unwrap();
1503        let dim = 4usize;
1504        write_demo_table(&dir, dim, 4).await;
1505
1506        let store_a: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1507        let store_b: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1508        let cat_a: Arc<dyn CatalogProvider> =
1509            Arc::new(HadoopCatalog::new(store_a.clone(), "warehouse"));
1510        let cat_b: Arc<dyn CatalogProvider> =
1511            Arc::new(HadoopCatalog::new(store_b.clone(), "warehouse"));
1512        let table = TableIdent::new("default", "table");
1513
1514        let query = vec![1.0f32, 0.0, 0.0, 0.0];
1515        let cfg_plain = SearchConfig {
1516            top_k: 2,
1517            ef_search: 50,
1518            pruning_threshold: f32::INFINITY,
1519            rerank_factor: None,
1520            score_fn: None,
1521            partition_filter: None,
1522            hybrid: None,
1523        };
1524        let cfg_rerank = SearchConfig {
1525            top_k: 2,
1526            ef_search: 50,
1527            pruning_threshold: f32::INFINITY,
1528            rerank_factor: Some(2),
1529            score_fn: None,
1530            partition_filter: None,
1531            hybrid: None,
1532        };
1533
1534        let plain = search(
1535            &table,
1536            &query,
1537            cfg_plain,
1538            "embedding",
1539            dim as u32,
1540            cat_a,
1541            store_a,
1542        )
1543        .await
1544        .unwrap();
1545        let reranked = search(
1546            &table,
1547            &query,
1548            cfg_rerank,
1549            "embedding",
1550            dim as u32,
1551            cat_b,
1552            store_b,
1553        )
1554        .await
1555        .unwrap();
1556
1557        // Both should return same top-1 result (row 0, distance ~0)
1558        assert_eq!(plain[0].row_id, reranked[0].row_id);
1559    }
1560
1561    #[tokio::test]
1562    async fn multimodal_rrf_returns_top_k() {
1563        let dir = TempDir::new().unwrap();
1564        let dim = 4usize;
1565        write_demo_table(&dir, dim, 4).await;
1566
1567        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1568        let catalog: Arc<dyn CatalogProvider> =
1569            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
1570        let table = TableIdent::new("default", "table");
1571
1572        // Two modal queries using the same column (single-column table).
1573        // Different queries to exercise RRF merging.
1574        let q1 = vec![1.0f32, 0.0, 0.0, 0.0];
1575        let q2 = vec![0.0f32, 1.0, 0.0, 0.0];
1576
1577        let queries = vec![
1578            ModalQuery {
1579                column: "embedding",
1580                query: &q1,
1581                weight: 0.7,
1582                dim: dim as u32,
1583            },
1584            ModalQuery {
1585                column: "embedding",
1586                query: &q2,
1587                weight: 0.3,
1588                dim: dim as u32,
1589            },
1590        ];
1591
1592        let config = SearchConfig {
1593            top_k: 2,
1594            ef_search: 50,
1595            pruning_threshold: f32::INFINITY,
1596            rerank_factor: None,
1597            score_fn: None,
1598            partition_filter: None,
1599            hybrid: None,
1600        };
1601
1602        let results =
1603            search_multimodal(&table, &queries, config, catalog, store, FusionMethod::Rrf)
1604                .await
1605                .unwrap();
1606
1607        assert_eq!(results.len(), 2);
1608        // RRF score stored as -distance; all should be negative
1609        assert!(results[0].distance <= 0.0);
1610        // Top result should be one of rows 0 or 1 (nearest to q1 or q2)
1611        assert!(results[0].row_id.as_u64() < 4);
1612    }
1613
1614    /// True cross-modal test: two columns with DIFFERENT dims (4 + 2).
1615    /// Verifies that search_multimodal correctly routes to each column's HNSW
1616    /// and that the dim validation in search() handles secondary columns.
1617    #[tokio::test]
1618    async fn multimodal_rrf_cross_modal_different_dims() {
1619        let dir = TempDir::new().unwrap();
1620        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
1621        let catalog: Arc<dyn CatalogProvider> =
1622            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
1623        let table = TableIdent::new("default", "table");
1624
1625        // Write a 2-column table: "embedding" dim=4, "img_embedding" dim=2
1626        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
1627        let rows = 4usize;
1628        let ids: Vec<i32> = (0..rows as i32).collect();
1629        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap();
1630
1631        let text_embs: Vec<Vec<f32>> = (0..rows)
1632            .map(|i| {
1633                let mut v = vec![0.0f32; 4];
1634                v[i % 4] = 1.0;
1635                v
1636            })
1637            .collect();
1638        let img_embs: Vec<Vec<f32>> = (0..rows)
1639            .map(|i| {
1640                let mut v = vec![0.0f32; 2];
1641                v[i % 2] = 1.0;
1642                v
1643            })
1644            .collect();
1645
1646        let text_policy = make_policy(4);
1647        let img_policy = VectorStoragePolicy {
1648            column_name: "img_embedding".to_string(),
1649            dim: 2,
1650            metric: VectorMetric::Cosine,
1651            precision: VectorPrecision::F16,
1652            pq: None,
1653            keep_raw_for_reranking: true,
1654            pre_normalize: false,
1655            hnsw_m: None,
1656            hnsw_ef_construction: None,
1657            ivf_residual: false,
1658            embedding_model: None,
1659            modality: None,
1660            partition_by: None,
1661            partition_value: None,
1662            partition_column_type: None,
1663            partition_fields: vec![],
1664        };
1665
1666        let mut writer = crate::TableWriter::create_or_open(
1667            catalog.clone(),
1668            store.clone(),
1669            text_policy,
1670            table.clone(),
1671            2,
1672        )
1673        .await
1674        .unwrap();
1675
1676        let batches = [
1677            MultiVectorBatch {
1678                policy: make_policy(4),
1679                embeddings: &text_embs,
1680            },
1681            MultiVectorBatch {
1682                policy: img_policy,
1683                embeddings: &img_embs,
1684            },
1685        ];
1686        writer.write_batch_multi(&batch, &batches).await.unwrap();
1687        writer.commit().await.unwrap();
1688
1689        // Cross-modal search: text query (dim=4) + image query (dim=2).
1690        let q_text = vec![1.0f32, 0.0, 0.0, 0.0];
1691        let q_img = vec![1.0f32, 0.0];
1692
1693        let queries = vec![
1694            ModalQuery {
1695                column: "embedding",
1696                query: &q_text,
1697                weight: 0.6,
1698                dim: 4,
1699            },
1700            ModalQuery {
1701                column: "img_embedding",
1702                query: &q_img,
1703                weight: 0.4,
1704                dim: 2,
1705            },
1706        ];
1707        let config = SearchConfig {
1708            top_k: 2,
1709            ef_search: 50,
1710            pruning_threshold: f32::INFINITY,
1711            rerank_factor: None,
1712            score_fn: None,
1713            partition_filter: None,
1714            hybrid: None,
1715        };
1716
1717        let results =
1718            search_multimodal(&table, &queries, config, catalog, store, FusionMethod::Rrf)
1719                .await
1720                .unwrap();
1721
1722        assert!(!results.is_empty(), "should return results");
1723        assert!(results[0].distance <= 0.0, "distance is -rrf_score");
1724        // Row 0 is nearest to both q_text=[1,0,0,0] and q_img=[1,0]
1725        assert_eq!(results[0].row_id.as_u64(), 0, "row 0 should rank first");
1726    }
1727}