1use std::sync::Arc;
3
4use rayon::prelude::*;
5use tracing::{debug, error};
6
7use ailake_catalog::{CatalogProvider, DataFileEntry, IndexStatus, TableIdent};
8use ailake_core::{AilakeError, AilakeResult, EmbeddingModelInfo, RowId, VectorMetric};
9use ailake_file::AilakeFileReader;
10use ailake_index::AnyIndex;
11use ailake_store::Store;
12use ailake_vec::exact_distance;
13use arrow_array::RecordBatch;
14use bytes::Bytes;
15
16use crate::pruner::VectorPruner;
17
/// Tuning parameters for one vector search request.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Number of results kept after the final sort-and-truncate step.
    pub top_k: usize,
    /// Passed unchanged as the last argument of every index `search` call.
    // NOTE(review): presumably the HNSW search beam width — confirm against the index crate.
    pub ef_search: usize,
    /// Threshold handed to file/shard pruning. In `SearchSession::search_query` a shard is
    /// skipped when its centroid distance minus its radius exceeds this value.
    pub pruning_threshold: f32,
    /// When `Some(factor)`, each file/shard is asked for `top_k * factor` candidates and the
    /// indexed candidates are re-scored with exact distances where raw vectors are available.
    /// `None` requests exactly `top_k` candidates and skips re-scoring.
    pub rerank_factor: Option<usize>,
}
32
33impl Default for SearchConfig {
34 fn default() -> Self {
35 Self {
36 top_k: 10,
37 ef_search: 50,
38 pruning_threshold: f32::INFINITY,
39 rerank_factor: None,
40 }
41 }
42}
43
44impl SearchConfig {
45 pub fn with_pruning(mut self, threshold: f32) -> Self {
46 self.pruning_threshold = threshold;
47 self
48 }
49
50 pub fn with_reranking(mut self, factor: usize) -> Self {
51 self.rerank_factor = Some(factor);
52 self
53 }
54}
55
/// One hit produced by the search functions in this module.
#[derive(Debug)]
pub struct SearchResult {
    /// Row identifier within `file_path`; `fetch_rows` uses it as a row offset into the file.
    pub row_id: RowId,
    /// Score the hit was ranked by; results are sorted ascending on this field.
    /// `search_multimodal` stores the negated fused RRF score here instead of a distance.
    pub distance: f32,
    /// Path of the data file the row was found in.
    pub file_path: String,
}
62
63pub async fn search(
71 table: &TableIdent,
72 query: &[f32],
73 config: SearchConfig,
74 vector_column: &str,
75 dim: u32,
76 catalog: Arc<dyn CatalogProvider>,
77 store: Arc<dyn Store>,
78) -> AilakeResult<Vec<SearchResult>> {
79 let all_files = catalog.list_files(table, None).await?;
81
82 let table_meta = catalog.load_table(table).await?;
84
85 let primary_col = table_meta
90 .properties
91 .get("ailake.vector-column")
92 .map(String::as_str)
93 .unwrap_or("");
94 let stored_dim_key = if vector_column == primary_col {
95 "ailake.vector-dim".to_string()
96 } else {
97 format!("ailake.dim-{vector_column}")
98 };
99 if let Some(table_dim_str) = table_meta.properties.get(&stored_dim_key) {
100 if let Ok(table_dim) = table_dim_str.parse::<u32>() {
101 let query_dim = query.len() as u32;
102 if query_dim != table_dim {
103 let table_model = table_meta
104 .properties
105 .get(EmbeddingModelInfo::property_key())
106 .cloned()
107 .unwrap_or_else(|| format!("dim={}", table_dim));
108 return Err(AilakeError::ModelMismatch {
109 table_model,
110 table_dim,
111 batch_model: format!("query dim={}", query_dim),
112 batch_dim: query_dim,
113 });
114 }
115 }
116 }
117
118 let metric_key = if vector_column == primary_col {
120 "ailake.vector-metric".to_string()
121 } else {
122 format!("ailake.metric-{vector_column}")
123 };
124 let metric = parse_metric(
125 table_meta
126 .properties
127 .get(&metric_key)
128 .or_else(|| table_meta.properties.get("ailake.vector-metric"))
129 .map(String::as_str)
130 .unwrap_or("cosine"),
131 );
132
133 let total_files = all_files.len();
135 let surviving_files = VectorPruner::prune(all_files, query, metric, config.pruning_threshold);
136 debug!(
137 "ailake: geometric pruning — {}/{} files survive (threshold={})",
138 surviving_files.len(),
139 total_files,
140 config.pruning_threshold
141 );
142
143 let candidate_k = match config.rerank_factor {
144 Some(factor) => config.top_k * factor,
145 None => config.top_k,
146 };
147
148 let mut all_results: Vec<SearchResult> = Vec::new();
149
150 for file_entry in &surviving_files {
151 let file_bytes: Bytes = store.get(&file_entry.path).await?;
152 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
153
154 if file_entry.index_status == IndexStatus::Indexing || !reader.is_ailake_file() {
155 debug!(
157 "ailake: flat scan fallback for {} (index_status={:?})",
158 file_entry.path, file_entry.index_status
159 );
160 let (_, raw_vectors) = reader.read_parquet()?;
161 for (row_id, distance) in flat_search(&raw_vectors, query, candidate_k, metric) {
162 all_results.push(SearchResult {
163 row_id,
164 distance,
165 file_path: file_entry.path.clone(),
166 });
167 }
168 continue;
169 }
170
171 let index = reader.load_any_index_for_column(vector_column)?;
172 let local_results = index.search(query, candidate_k, config.ef_search);
173
174 if config.rerank_factor.is_some() {
175 let (_, raw_vectors) = reader.read_parquet()?;
177 for (row_id, _approx_dist) in local_results {
178 let idx = row_id.as_u64() as usize;
179 let exact_dist = match raw_vectors.get(idx) {
180 Some(v) => exact_distance(metric, query, v),
181 None => {
182 error!(
183 "ailake: invariant violated — row_id {} out of bounds \
184 (raw_vectors.len={}, file={}); \
185 Parquet row count and HNSW node count are out of sync; \
186 file may be corrupt — run compaction to rebuild",
187 idx,
188 raw_vectors.len(),
189 file_entry.path
190 );
191 f32::INFINITY
192 }
193 };
194 all_results.push(SearchResult {
195 row_id,
196 distance: exact_dist,
197 file_path: file_entry.path.clone(),
198 });
199 }
200 } else {
201 for (row_id, distance) in local_results {
202 all_results.push(SearchResult {
203 row_id,
204 distance,
205 file_path: file_entry.path.clone(),
206 });
207 }
208 }
209 }
210
211 all_results.sort_by(|a, b| {
213 a.distance
214 .partial_cmp(&b.distance)
215 .unwrap_or(std::cmp::Ordering::Equal)
216 });
217 all_results.truncate(config.top_k);
218 Ok(all_results)
219}
220
/// One per-column query fed into `search_multimodal`.
#[derive(Debug, Clone)]
pub struct ModalQuery<'a> {
    /// Name of the vector column to search.
    pub column: &'a str,
    /// Query vector for that column.
    pub query: &'a [f32],
    /// Multiplier applied to this column's reciprocal-rank contributions during fusion.
    pub weight: f32,
    /// Vector dimension passed to `search`. A value of `0` makes `search_multimodal` resolve
    /// it from the table properties, falling back to `query.len()` for non-primary columns.
    pub dim: u32,
}
235
/// Strategy used by `search_multimodal` to merge per-column result lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal-rank fusion: a hit at zero-based `rank` in a column's list contributes
    /// `weight / (60 + rank + 1)` to its fused score.
    Rrf,
}
244
245pub async fn search_multimodal(
255 table: &TableIdent,
256 queries: &[ModalQuery<'_>],
257 config: SearchConfig,
258 catalog: Arc<dyn CatalogProvider>,
259 store: Arc<dyn Store>,
260 fusion: FusionMethod,
261) -> AilakeResult<Vec<SearchResult>> {
262 use std::collections::HashMap;
263
264 if queries.is_empty() {
265 return Err(AilakeError::InvalidArgument(
266 "search_multimodal requires at least one ModalQuery".into(),
267 ));
268 }
269
270 let table_meta = catalog.load_table(table).await?;
272 let primary_col = table_meta
273 .properties
274 .get("ailake.vector-column")
275 .cloned()
276 .unwrap_or_default();
277 let primary_dim: u32 = table_meta
278 .properties
279 .get("ailake.vector-dim")
280 .and_then(|s| s.parse().ok())
281 .unwrap_or(0);
282
283 let per_col_k = (config.top_k * queries.len().max(2)).min(1000);
285
286 let mut per_col_results: Vec<(f32, Vec<SearchResult>)> = Vec::with_capacity(queries.len());
287 for mq in queries {
288 let resolved_dim = if mq.dim > 0 {
290 mq.dim
291 } else if mq.column == primary_col {
292 primary_dim
293 } else {
294 table_meta
295 .properties
296 .get(&format!("ailake.dim-{}", mq.column))
297 .and_then(|s| s.parse().ok())
298 .unwrap_or(mq.query.len() as u32)
299 };
300
301 let col_config = SearchConfig {
302 top_k: per_col_k,
303 ef_search: config.ef_search,
304 pruning_threshold: config.pruning_threshold,
305 rerank_factor: config.rerank_factor,
306 };
307 let results = search(
308 table,
309 mq.query,
310 col_config,
311 mq.column,
312 resolved_dim,
313 catalog.clone(),
314 store.clone(),
315 )
316 .await?;
317 per_col_results.push((mq.weight, results));
318 }
319
320 const K: f32 = 60.0;
322 let mut scores: HashMap<(String, u64), f32> = HashMap::new();
323
324 for (weight, results) in &per_col_results {
325 for (rank, r) in results.iter().enumerate() {
326 let key = (r.file_path.clone(), r.row_id.as_u64());
327 let rrf = weight / (K + rank as f32 + 1.0);
328 *scores.entry(key).or_insert(0.0) += rrf;
329 }
330 }
331
332 let all_files = catalog.list_files(table, None).await?;
335 let _ = all_files; let mut seen: HashMap<(String, u64), f32> = HashMap::new();
339 for (_, results) in &per_col_results {
340 for r in results {
341 let key = (r.file_path.clone(), r.row_id.as_u64());
342 let rrf_score = *scores.get(&key).unwrap_or(&0.0);
343 seen.entry(key).or_insert(rrf_score);
344 }
345 }
346
347 let mut fused: Vec<SearchResult> = seen
348 .into_iter()
349 .map(|((file_path, row_id_u64), rrf_score)| SearchResult {
350 row_id: RowId::new(row_id_u64),
351 distance: -rrf_score,
352 file_path,
353 })
354 .collect();
355
356 fused.sort_by(|a, b| {
357 a.distance
358 .partial_cmp(&b.distance)
359 .unwrap_or(std::cmp::Ordering::Equal)
360 });
361 fused.truncate(config.top_k);
362
363 let _ = fusion; Ok(fused)
366}
367
368fn flat_search(
370 raw: &[Vec<f32>],
371 query: &[f32],
372 top_k: usize,
373 metric: VectorMetric,
374) -> Vec<(RowId, f32)> {
375 let mut results: Vec<(RowId, f32)> = raw
376 .iter()
377 .enumerate()
378 .map(|(i, v)| (RowId::new(i as u64), exact_distance(metric, query, v)))
379 .collect();
380 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
381 results.truncate(top_k);
382 results
383}
384
385fn parse_metric(s: &str) -> VectorMetric {
386 match s {
387 "euclidean" => VectorMetric::Euclidean,
388 "dotproduct" | "dot_product" | "dot" => VectorMetric::DotProduct,
389 _ => VectorMetric::Cosine,
390 }
391}
392
/// In-memory search state built by `SearchSession::load` and reused by
/// `search_batch` / `search_query` without further store or catalog access.
pub struct SearchSession {
    /// One entry per data file that was loaded.
    shards: Vec<LoadedShard>,
    /// Distance metric resolved from the table properties at load time.
    metric: VectorMetric,
}
401
/// A single data file held in memory by a `SearchSession`.
struct LoadedShard {
    /// Catalog entry of the file; supplies the path and the centroid used for pruning.
    entry: DataFileEntry,
    /// Index loaded from the file; `None` for files whose status was `Indexing` at load time.
    index: Option<AnyIndex>,
    /// Raw vectors read from the file; populated for `Indexing` files and, for indexed
    /// files, only when `load_raw` was requested.
    raw_vectors: Option<Vec<Vec<f32>>>,
}
410
411impl SearchSession {
412 pub async fn load(
418 table: &TableIdent,
419 vector_column: &str,
420 dim: u32,
421 catalog: Arc<dyn CatalogProvider>,
422 store: Arc<dyn Store>,
423 load_raw: bool,
424 ) -> AilakeResult<Self> {
425 let all_files = catalog.list_files(table, None).await?;
426 let table_meta = catalog.load_table(table).await?;
427 let metric = parse_metric(
428 table_meta
429 .properties
430 .get("ailake.vector-metric")
431 .map(String::as_str)
432 .unwrap_or("cosine"),
433 );
434
435 let mut shards = Vec::with_capacity(all_files.len());
436 for entry in all_files {
437 let file_bytes: Bytes = store.get(&entry.path).await?;
438 let reader = AilakeFileReader::new(file_bytes, vector_column, dim);
439
440 if entry.index_status == IndexStatus::Indexing {
441 let (_, raw_vecs) = reader.read_parquet()?;
443 shards.push(LoadedShard {
444 entry,
445 index: None,
446 raw_vectors: Some(raw_vecs),
447 });
448 } else if reader.is_ailake_file() {
449 let mut index = reader.load_any_index_for_column(vector_column)?;
450 let raw_vectors = if load_raw {
451 index.quantize_to_f16();
452 let (_, vecs) = reader.read_parquet()?;
453 Some(vecs)
454 } else {
455 None
456 };
457 shards.push(LoadedShard {
458 entry,
459 index: Some(index),
460 raw_vectors,
461 });
462 }
463 }
464
465 Ok(Self { shards, metric })
466 }
467
468 pub fn shard_count(&self) -> usize {
470 self.shards.len()
471 }
472
473 pub fn search_batch(
482 &self,
483 queries: &[Vec<f32>],
484 config: &SearchConfig,
485 ) -> Vec<Vec<SearchResult>> {
486 if queries.is_empty() {
487 return vec![];
488 }
489
490 let n_queries = queries.len();
491 let candidate_k = match config.rerank_factor {
492 Some(factor) => config.top_k * factor,
493 None => config.top_k,
494 };
495 let use_nvidia = ailake_index::hardware::detect_cuda();
496 let use_amd = ailake_index::hardware::detect_rocm();
497
498 let mut all_results: Vec<Vec<SearchResult>> = (0..n_queries).map(|_| Vec::new()).collect();
500
501 for shard in &self.shards {
502 if let Some(raw) = &shard.raw_vectors {
503 if !raw.is_empty() {
505 let dim = raw[0].len();
506 let flat: Vec<f32> = raw.iter().flat_map(|v| v.iter().copied()).collect();
507 let row_ids: Vec<u64> = (0..raw.len() as u64).collect();
508 let q_refs: Vec<&[f32]> = queries.iter().map(|q| q.as_slice()).collect();
509
510 let gpu_batch = if use_nvidia {
511 ailake_index::gpu::try_nvidia_search_batch(
512 &q_refs,
513 &row_ids,
514 &flat,
515 dim,
516 self.metric,
517 candidate_k,
518 )
519 } else if use_amd {
520 ailake_index::gpu::try_rocm_search_batch(
521 &q_refs,
522 &row_ids,
523 &flat,
524 dim,
525 self.metric,
526 candidate_k,
527 )
528 } else {
529 None
530 };
531
532 if let Some(batch) = gpu_batch {
533 for (qi, results) in batch.into_iter().enumerate() {
534 for (row_id, distance) in results {
535 all_results[qi].push(SearchResult {
536 row_id,
537 distance,
538 file_path: shard.entry.path.clone(),
539 });
540 }
541 }
542 continue;
543 }
544 }
545
546 for (qi, query) in queries.iter().enumerate() {
548 for (row_id, distance) in flat_search(raw, query, candidate_k, self.metric) {
549 all_results[qi].push(SearchResult {
550 row_id,
551 distance,
552 file_path: shard.entry.path.clone(),
553 });
554 }
555 }
556 } else if let Some(index) = &shard.index {
557 let shard_results: Vec<Vec<SearchResult>> = queries
559 .par_iter()
560 .map(|query| {
561 index
562 .search(query, candidate_k, config.ef_search)
563 .into_iter()
564 .map(|(row_id, distance)| SearchResult {
565 row_id,
566 distance,
567 file_path: shard.entry.path.clone(),
568 })
569 .collect()
570 })
571 .collect();
572
573 for (qi, results) in shard_results.into_iter().enumerate() {
574 all_results[qi].extend(results);
575 }
576 }
577 }
578
579 for results in &mut all_results {
581 results.sort_by(|a, b| {
582 a.distance
583 .partial_cmp(&b.distance)
584 .unwrap_or(std::cmp::Ordering::Equal)
585 });
586 results.truncate(config.top_k);
587 }
588
589 all_results
590 }
591
592 pub fn search_query(&self, query: &[f32], config: &SearchConfig) -> Vec<SearchResult> {
594 let candidate_k = match config.rerank_factor {
595 Some(factor) => config.top_k * factor,
596 None => config.top_k,
597 };
598
599 let mut all_results: Vec<SearchResult> = self
600 .shards
601 .par_iter()
602 .flat_map(|shard| {
603 if let Some(centroid) = ailake_catalog::decode_centroid(&shard.entry, self.metric) {
605 let dist = match self.metric {
606 VectorMetric::Cosine | VectorMetric::NormalizedCosine => {
607 ailake_vec::cosine_distance(query, ¢roid.values)
608 }
609 VectorMetric::Euclidean => {
610 ailake_vec::euclidean_distance(query, ¢roid.values)
611 }
612 VectorMetric::DotProduct => {
613 -ailake_vec::dot_product(query, ¢roid.values)
614 }
615 };
616 if dist - centroid.radius > config.pruning_threshold {
617 return vec![];
618 }
619 }
620
621 if let Some(index) = &shard.index {
622 let local_results = index.search(query, candidate_k, config.ef_search);
624 if config.rerank_factor.is_some() {
625 if let Some(raw) = &shard.raw_vectors {
626 local_results
627 .into_iter()
628 .map(|(row_id, _approx_dist)| {
629 let idx = row_id.as_u64() as usize;
630 let exact_dist = raw
631 .get(idx)
632 .map(|v| exact_distance(self.metric, query, v))
633 .unwrap_or(f32::INFINITY);
634 SearchResult {
635 row_id,
636 distance: exact_dist,
637 file_path: shard.entry.path.clone(),
638 }
639 })
640 .collect()
641 } else {
642 local_results
643 .into_iter()
644 .map(|(row_id, distance)| SearchResult {
645 row_id,
646 distance,
647 file_path: shard.entry.path.clone(),
648 })
649 .collect()
650 }
651 } else {
652 local_results
653 .into_iter()
654 .map(|(row_id, distance)| SearchResult {
655 row_id,
656 distance,
657 file_path: shard.entry.path.clone(),
658 })
659 .collect()
660 }
661 } else if let Some(raw) = &shard.raw_vectors {
662 flat_search(raw, query, candidate_k, self.metric)
664 .into_iter()
665 .map(|(row_id, distance)| SearchResult {
666 row_id,
667 distance,
668 file_path: shard.entry.path.clone(),
669 })
670 .collect()
671 } else {
672 vec![]
673 }
674 })
675 .collect();
676
677 all_results.sort_by(|a, b| {
678 a.distance
679 .partial_cmp(&b.distance)
680 .unwrap_or(std::cmp::Ordering::Equal)
681 });
682 all_results.truncate(config.top_k);
683 all_results
684 }
685}
686
687pub async fn fetch_rows(
696 results: &[SearchResult],
697 store: Arc<dyn Store>,
698 vector_column: &str,
699 dim: u32,
700) -> AilakeResult<RecordBatch> {
701 use std::collections::HashMap;
702
703 use arrow_array::{ArrayRef, Float32Array, UInt32Array};
704 use arrow_schema::{DataType, Field, Schema};
705 use arrow_select::{concat::concat_batches, take::take};
706
707 if results.is_empty() {
708 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
709 }
710
711 let mut by_file: HashMap<&str, Vec<(u64, f32, usize)>> = HashMap::new();
713 for (i, r) in results.iter().enumerate() {
714 by_file
715 .entry(r.file_path.as_str())
716 .or_default()
717 .push((r.row_id.as_u64(), r.distance, i));
718 }
719
720 use arrow_array::FixedSizeListArray;
721
722 let mut collected: Vec<(usize, f32, RecordBatch, Vec<f32>)> = Vec::with_capacity(results.len());
724
725 for (file_path, rows) in &by_file {
726 let bytes = store.get(file_path).await?;
727 let reader = AilakeFileReader::new(bytes, vector_column, dim);
728 let (batch, vectors) = reader.read_parquet()?;
729
730 for &(row_id, distance, pos) in rows {
731 let idx = row_id as usize;
732 if idx >= batch.num_rows() {
733 tracing::warn!(
734 "fetch_rows: row_id {} out of bounds (file_rows={}, file={}), skipping",
735 idx,
736 batch.num_rows(),
737 file_path
738 );
739 continue;
740 }
741
742 let indices = UInt32Array::from(vec![idx as u32]);
743 let row_cols: Vec<ArrayRef> = batch
744 .columns()
745 .iter()
746 .map(|col| {
747 take(col.as_ref(), &indices, None)
748 .map_err(|e| AilakeError::Arrow(e.to_string()))
749 })
750 .collect::<AilakeResult<Vec<_>>>()?;
751
752 let row_batch = RecordBatch::try_new(batch.schema(), row_cols)
753 .map_err(|e| AilakeError::Arrow(e.to_string()))?;
754
755 let vec = vectors
757 .get(idx)
758 .cloned()
759 .unwrap_or_else(|| vec![0.0f32; dim as usize]);
760
761 collected.push((pos, distance, row_batch, vec));
762 }
763 }
764
765 if collected.is_empty() {
766 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
767 }
768
769 collected.sort_by_key(|(pos, _, _, _)| *pos);
771
772 let distances: Vec<f32> = collected.iter().map(|(_, d, _, _)| *d).collect();
773 let row_batches: Vec<&RecordBatch> = collected.iter().map(|(_, _, b, _)| b).collect();
774 let base_schema = collected[0].2.schema();
775
776 let combined =
777 concat_batches(&base_schema, row_batches).map_err(|e| AilakeError::Arrow(e.to_string()))?;
778
779 let flat_vecs: Vec<f32> = collected
781 .iter()
782 .flat_map(|(_, _, _, v)| v.iter().copied())
783 .collect();
784 let item_field = Arc::new(Field::new("item", DataType::Float32, false));
785 let values_arr = Arc::new(Float32Array::from(flat_vecs)) as ArrayRef;
786 let vec_col = FixedSizeListArray::new(item_field.clone(), dim as i32, values_arr, None);
787 let vec_field = Arc::new(Field::new(
788 vector_column,
789 DataType::FixedSizeList(item_field, dim as i32),
790 false,
791 ));
792
793 let mut fields: Vec<Arc<Field>> = base_schema.fields().to_vec();
795 fields.push(vec_field);
796 fields.push(Arc::new(Field::new("_distance", DataType::Float32, false)));
797 let new_schema = Arc::new(Schema::new(fields));
798
799 let mut columns: Vec<ArrayRef> = combined.columns().to_vec();
800 columns.push(Arc::new(vec_col));
801 columns.push(Arc::new(Float32Array::from(distances)));
802
803 RecordBatch::try_new(new_schema, columns).map_err(|e| AilakeError::Arrow(e.to_string()))
804}
805
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::MultiVectorBatch;
    use ailake_catalog::{HadoopCatalog, TableIdent};
    use ailake_core::{VectorMetric, VectorPrecision, VectorStoragePolicy};
    use ailake_store::LocalStore;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;
    use tempfile::TempDir;

    /// Storage policy for the `embedding` column with the given dimension.
    fn make_policy(dim: u32) -> VectorStoragePolicy {
        VectorStoragePolicy {
            column_name: "embedding".to_string(),
            dim,
            metric: VectorMetric::Cosine,
            precision: VectorPrecision::F16,
            pq: None,
            keep_raw_for_reranking: true,
            pre_normalize: false,
            hnsw_m: None,
            hnsw_ef_construction: None,
            ivf_residual: false,
            embedding_model: None,
            modality: None,
        }
    }

    /// Fresh store/catalog handles rooted at `dir`, plus the identifier of the test table.
    fn open_table(dir: &TempDir) -> (Arc<dyn Store>, Arc<dyn CatalogProvider>, TableIdent) {
        let store: Arc<dyn Store> = Arc::new(LocalStore::new(dir.path()));
        let catalog: Arc<dyn CatalogProvider> =
            Arc::new(HadoopCatalog::new(store.clone(), "warehouse"));
        (store, catalog, TableIdent::new("default", "table"))
    }

    /// Record batch with a single `id` column holding `0..rows`.
    fn id_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let ids: Vec<i32> = (0..rows as i32).collect();
        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(ids))]).unwrap()
    }

    /// Unit vector of length `dim` with a one at position `hot`.
    fn one_hot(dim: usize, hot: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[hot] = 1.0;
        v
    }

    /// `rows` one-hot vectors; row `i` is hot at `i % dim`.
    fn one_hot_rows(rows: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..rows).map(|i| one_hot(dim, i % dim)).collect()
    }

    /// Search configuration used by every test; only `top_k` and re-ranking vary.
    fn config(top_k: usize, rerank_factor: Option<usize>) -> SearchConfig {
        SearchConfig {
            top_k,
            ef_search: 50,
            pruning_threshold: f32::INFINITY,
            rerank_factor,
        }
    }

    /// Writes and commits `rows` one-hot embeddings of dimension `dim` to the test table.
    async fn write_demo_table(dir: &TempDir, dim: usize, rows: usize) {
        let (store, catalog, table) = open_table(dir);
        let batch = id_batch(rows);
        let embeddings = one_hot_rows(rows, dim);

        let mut writer =
            crate::TableWriter::create_or_open(catalog, store, make_policy(dim as u32), table)
                .await
                .unwrap();
        writer.write_batch(&batch, &embeddings).await.unwrap();
        writer.commit().await.unwrap();
    }

    #[tokio::test]
    async fn rerank_returns_correct_top_k_count() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_table(&dir);

        let query = one_hot(dim, 0);
        let results = search(
            &table,
            &query,
            config(3, Some(2)),
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn rerank_nearest_is_exact_match() {
        let dir = TempDir::new().unwrap();
        let dim = 8usize;
        write_demo_table(&dir, dim, 8).await;
        let (store, catalog, table) = open_table(&dir);

        // Identical to the embedding stored for row 0.
        let query = one_hot(dim, 0);
        let results = search(
            &table,
            &query,
            config(1, Some(4)),
            "embedding",
            dim as u32,
            catalog,
            store,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 1);
        assert!(
            results[0].distance < 1e-3,
            "distance was {}",
            results[0].distance
        );
        assert_eq!(results[0].row_id, RowId::new(0));
    }

    #[tokio::test]
    async fn no_rerank_matches_default_behavior() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;

        // Independent handles for the two runs.
        let (store_a, cat_a, table) = open_table(&dir);
        let (store_b, cat_b, _) = open_table(&dir);

        let query = one_hot(dim, 0);

        let plain = search(
            &table,
            &query,
            config(2, None),
            "embedding",
            dim as u32,
            cat_a,
            store_a,
        )
        .await
        .unwrap();
        let reranked = search(
            &table,
            &query,
            config(2, Some(2)),
            "embedding",
            dim as u32,
            cat_b,
            store_b,
        )
        .await
        .unwrap();

        assert_eq!(plain[0].row_id, reranked[0].row_id);
    }

    #[tokio::test]
    async fn multimodal_rrf_returns_top_k() {
        let dir = TempDir::new().unwrap();
        let dim = 4usize;
        write_demo_table(&dir, dim, 4).await;
        let (store, catalog, table) = open_table(&dir);

        let q1 = one_hot(dim, 0);
        let q2 = one_hot(dim, 1);

        // Two differently weighted queries against the same column.
        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q1,
                weight: 0.7,
                dim: dim as u32,
            },
            ModalQuery {
                column: "embedding",
                query: &q2,
                weight: 0.3,
                dim: dim as u32,
            },
        ];

        let results = search_multimodal(
            &table,
            &queries,
            config(2, None),
            catalog,
            store,
            FusionMethod::Rrf,
        )
        .await
        .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].distance <= 0.0);
        assert!(results[0].row_id.as_u64() < 4);
    }

    #[tokio::test]
    async fn multimodal_rrf_cross_modal_different_dims() {
        let dir = TempDir::new().unwrap();
        let (store, catalog, table) = open_table(&dir);

        let rows = 4usize;
        let batch = id_batch(rows);
        let text_embs = one_hot_rows(rows, 4);
        let img_embs = one_hot_rows(rows, 2);

        // Second vector column with its own name and a smaller dimension.
        let img_policy = VectorStoragePolicy {
            column_name: "img_embedding".to_string(),
            ..make_policy(2)
        };

        let mut writer = crate::TableWriter::create_or_open(
            catalog.clone(),
            store.clone(),
            make_policy(4),
            table.clone(),
        )
        .await
        .unwrap();

        let batches = [
            MultiVectorBatch {
                policy: make_policy(4),
                embeddings: &text_embs,
            },
            MultiVectorBatch {
                policy: img_policy,
                embeddings: &img_embs,
            },
        ];
        writer.write_batch_multi(&batch, &batches).await.unwrap();
        writer.commit().await.unwrap();

        let q_text = one_hot(4, 0);
        let q_img = one_hot(2, 0);

        let queries = vec![
            ModalQuery {
                column: "embedding",
                query: &q_text,
                weight: 0.6,
                dim: 4,
            },
            ModalQuery {
                column: "img_embedding",
                query: &q_img,
                weight: 0.4,
                dim: 2,
            },
        ];

        let results = search_multimodal(
            &table,
            &queries,
            config(2, None),
            catalog,
            store,
            FusionMethod::Rrf,
        )
        .await
        .unwrap();

        assert!(!results.is_empty(), "should return results");
        assert!(results[0].distance <= 0.0, "distance is -rrf_score");
        assert_eq!(results[0].row_id.as_u64(), 0, "row 0 should rank first");
    }
}