Skip to main content

lance_index/scalar/inverted/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::{
7    cmp::{min, Reverse},
8    collections::BinaryHeap,
9};
10use std::{collections::HashMap, ops::Range};
11
12use crate::metrics::NoOpMetricsCollector;
13use crate::prefilter::NoFilter;
14use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
15use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
16use arrow::{
17    array::{
18        AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
19    },
20    buffer::OffsetBuffer,
21};
22use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
23use arrow_array::{
24    Array, ArrayRef, BooleanArray, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait,
25    RecordBatch, UInt32Array, UInt64Array,
26};
27use arrow_schema::{DataType, Field, Schema, SchemaRef};
28use async_trait::async_trait;
29use datafusion::execution::SendableRecordBatchStream;
30use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31use datafusion_common::DataFusionError;
32use deepsize::DeepSizeOf;
33use fst::{Automaton, IntoStreamer, Streamer};
34use futures::{stream, FutureExt, StreamExt, TryStreamExt};
35use itertools::Itertools;
36use lance_arrow::{iter_str_array, RecordBatchExt};
37use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
38use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
39use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
40use lance_core::{
41    container::list::ExpLinkedList,
42    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
43};
44use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
45use roaring::RoaringBitmap;
46use snafu::location;
47use std::sync::LazyLock;
48use tokio::task::spawn_blocking;
49use tracing::{info, instrument};
50
51use super::{
52    builder::{
53        doc_file_path, inverted_list_schema, posting_file_path, token_file_path, ScoredDoc,
54        BLOCK_SIZE,
55    },
56    iter::PlainPostingListIterator,
57    query::*,
58    scorer::{idf, IndexBM25Scorer, Scorer, B, K1},
59};
60use super::{
61    builder::{InnerBuilder, PositionRecorder},
62    encoding::compress_posting_list,
63    iter::CompressedPostingListIterator,
64};
65use super::{encoding::compress_positions, iter::PostingListIterator};
66use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
67use crate::frag_reuse::FragReuseIndex;
68use crate::pbold;
69use crate::scalar::inverted::lance_tokenizer::TextTokenizer;
70use crate::scalar::inverted::scorer::MemBM25Scorer;
71use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
72use crate::scalar::{
73    AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
74    ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
75};
76use crate::Index;
77use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
78use std::str::FromStr;
79
/// Format version written for newly created inverted indices.
///
/// * Version 0: Arrow `TokenSetFormat` (legacy).
/// * Version 1: Fst `TokenSetFormat` (new default; clients older than 0.38 cannot read it).
pub const INVERTED_INDEX_VERSION: u32 = 1;

// Index file names. The tokens / invert / docs files form the legacy single-partition
// layout; the metadata file is only present in the partitioned layout.
pub const TOKENS_FILE: &str = "tokens.lance";
pub const INVERT_LIST_FILE: &str = "invert.lance";
pub const DOCS_FILE: &str = "docs.lance";
pub const METADATA_FILE: &str = "metadata.lance";

// Column names used in the index files.
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
pub const SCORE_COL: &str = "_score";

/// Schema-metadata key recording which `TokenSetFormat` the partitions use.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";
103
104pub static SCORE_FIELD: LazyLock<Field> =
105    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
106pub static FTS_SCHEMA: LazyLock<SchemaRef> =
107    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
108static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
109    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
110
111#[derive(Debug)]
112struct PartitionCandidates {
113    tokens_by_position: Vec<String>,
114    candidates: Vec<DocCandidate>,
115}
116
117impl PartitionCandidates {
118    fn empty() -> Self {
119        Self {
120            tokens_by_position: Vec::new(),
121            candidates: Vec::new(),
122        }
123    }
124}
125
/// On-disk encoding of a partition's token set.
///
/// `Arrow` is the legacy encoding (index version 0); `Fst` is the default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
    /// Legacy Arrow-encoded token set.
    Arrow,
    /// FST-encoded token set (default).
    #[default]
    Fst,
}

impl Display for TokenSetFormat {
    /// Writes the lowercase format name (`"arrow"` or `"fst"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Fst => "fst",
            Self::Arrow => "arrow",
        };
        f.write_str(name)
    }
}
141
142impl FromStr for TokenSetFormat {
143    type Err = Error;
144
145    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
146        match s.trim() {
147            "" => Ok(Self::Arrow),
148            "arrow" => Ok(Self::Arrow),
149            "fst" => Ok(Self::Fst),
150            other => Err(Error::Index {
151                message: format!("unsupported token set format {}", other),
152                location: location!(),
153            }),
154        }
155    }
156}
157
158impl DeepSizeOf for TokenSetFormat {
159    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
160        0
161    }
162}
163
164#[derive(Clone)]
165pub struct InvertedIndex {
166    params: InvertedIndexParams,
167    store: Arc<dyn IndexStore>,
168    tokenizer: Box<dyn LanceTokenizer>,
169    token_set_format: TokenSetFormat,
170    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
171}
172
173impl Debug for InvertedIndex {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("InvertedIndex")
176            .field("params", &self.params)
177            .field("token_set_format", &self.token_set_format)
178            .field("partitions", &self.partitions)
179            .finish()
180    }
181}
182
183impl DeepSizeOf for InvertedIndex {
184    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
185        self.partitions.deep_size_of_children(context)
186    }
187}
188
189impl InvertedIndex {
190    fn to_builder(&self) -> InvertedIndexBuilder {
191        self.to_builder_with_offset(None)
192    }
193
194    fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
195        if self.is_legacy() {
196            // for legacy format, we re-create the index in the new format
197            InvertedIndexBuilder::from_existing_index(
198                self.params.clone(),
199                None,
200                Vec::new(),
201                self.token_set_format,
202                fragment_mask,
203            )
204        } else {
205            let partitions = match fragment_mask {
206                Some(fragment_mask) => self
207                    .partitions
208                    .iter()
209                    // Filter partitions that belong to the specified fragment
210                    // The mask contains fragment_id in high 32 bits, we check if partition's
211                    // fragment_id matches by comparing the masked result with the original mask
212                    .filter(|part| part.belongs_to_fragment(fragment_mask))
213                    .map(|part| part.id())
214                    .collect(),
215                None => self.partitions.iter().map(|part| part.id()).collect(),
216            };
217
218            InvertedIndexBuilder::from_existing_index(
219                self.params.clone(),
220                Some(self.store.clone()),
221                partitions,
222                self.token_set_format,
223                fragment_mask,
224            )
225        }
226    }
227
228    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
229        self.tokenizer.clone()
230    }
231
232    pub fn params(&self) -> &InvertedIndexParams {
233        &self.params
234    }
235
236    /// Returns the number of partitions in this inverted index.
237    pub fn partition_count(&self) -> usize {
238        self.partitions.len()
239    }
240
241    // search the documents that contain the query
242    // return the row ids of the documents sorted by bm25 score
243    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
244    // we first calculate in-partition BM25 scores,
245    // then re-calculate the scores for the top k documents across all partitions
246    #[instrument(level = "debug", skip_all)]
247    pub async fn bm25_search(
248        &self,
249        tokens: Arc<Tokens>,
250        params: Arc<FtsSearchParams>,
251        operator: Operator,
252        prefilter: Arc<dyn PreFilter>,
253        metrics: Arc<dyn MetricsCollector>,
254    ) -> Result<(Vec<u64>, Vec<f32>)> {
255        let limit = params.limit.unwrap_or(usize::MAX);
256        if limit == 0 {
257            return Ok((Vec::new(), Vec::new()));
258        }
259        let mask = prefilter.mask();
260
261        let mut candidates = BinaryHeap::new();
262        let parts = self
263            .partitions
264            .iter()
265            .map(|part| {
266                let part = part.clone();
267                let tokens = tokens.clone();
268                let params = params.clone();
269                let mask = mask.clone();
270                let metrics = metrics.clone();
271                async move {
272                    let postings = part
273                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
274                        .await?;
275                    if postings.is_empty() {
276                        return Ok(PartitionCandidates::empty());
277                    }
278                    let mut tokens_by_position = vec![String::new(); postings.len()];
279                    for posting in &postings {
280                        let idx = posting.term_index() as usize;
281                        tokens_by_position[idx] = posting.token().to_owned();
282                    }
283                    let params = params.clone();
284                    let mask = mask.clone();
285                    let metrics = metrics.clone();
286                    spawn_cpu(move || {
287                        let candidates = part.bm25_search(
288                            params.as_ref(),
289                            operator,
290                            mask,
291                            postings,
292                            metrics.as_ref(),
293                        )?;
294                        Ok(PartitionCandidates {
295                            tokens_by_position,
296                            candidates,
297                        })
298                    })
299                    .await
300                }
301            })
302            .collect::<Vec<_>>();
303        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
304        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
305        let mut idf_cache: HashMap<String, f32> = HashMap::new();
306        while let Some(res) = parts.try_next().await? {
307            if res.candidates.is_empty() {
308                continue;
309            }
310            let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
311            for token in &res.tokens_by_position {
312                let idf_weight = match idf_cache.get(token) {
313                    Some(weight) => *weight,
314                    None => {
315                        let weight = scorer.query_weight(token);
316                        idf_cache.insert(token.clone(), weight);
317                        weight
318                    }
319                };
320                idf_by_position.push(idf_weight);
321            }
322            for DocCandidate {
323                row_id,
324                freqs,
325                doc_length,
326            } in res.candidates
327            {
328                let mut score = 0.0;
329                for (term_index, freq) in freqs.into_iter() {
330                    debug_assert!((term_index as usize) < idf_by_position.len());
331                    score +=
332                        idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length);
333                }
334                if candidates.len() < limit {
335                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
336                } else if candidates.peek().unwrap().0.score.0 < score {
337                    candidates.pop();
338                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
339                }
340            }
341        }
342
343        Ok(candidates
344            .into_sorted_vec()
345            .into_iter()
346            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
347            .unzip())
348    }
349
350    async fn load_legacy_index(
351        store: Arc<dyn IndexStore>,
352        frag_reuse_index: Option<Arc<FragReuseIndex>>,
353        index_cache: &LanceCache,
354    ) -> Result<Arc<Self>> {
355        log::warn!("loading legacy FTS index");
356        let tokens_fut = tokio::spawn({
357            let store = store.clone();
358            async move {
359                let token_reader = store.open_index_file(TOKENS_FILE).await?;
360                let tokenizer = token_reader
361                    .schema()
362                    .metadata
363                    .get("tokenizer")
364                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
365                    .transpose()?
366                    .unwrap_or_default();
367                let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
368                Result::Ok((tokenizer, tokens))
369            }
370        });
371        let invert_list_fut = tokio::spawn({
372            let store = store.clone();
373            let index_cache_clone = index_cache.clone();
374            async move {
375                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
376                let invert_list =
377                    PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
378                Result::Ok(Arc::new(invert_list))
379            }
380        });
381        let docs_fut = tokio::spawn({
382            let store = store.clone();
383            async move {
384                let docs_reader = store.open_index_file(DOCS_FILE).await?;
385                let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
386                Result::Ok(docs)
387            }
388        });
389
390        let (tokenizer_config, tokens) = tokens_fut.await??;
391        let inverted_list = invert_list_fut.await??;
392        let docs = docs_fut.await??;
393
394        let tokenizer = tokenizer_config.build()?;
395
396        Ok(Arc::new(Self {
397            params: tokenizer_config,
398            store: store.clone(),
399            tokenizer,
400            token_set_format: TokenSetFormat::Arrow,
401            partitions: vec![Arc::new(InvertedPartition {
402                id: 0,
403                store,
404                tokens,
405                inverted_list,
406                docs,
407                token_set_format: TokenSetFormat::Arrow,
408            })],
409        }))
410    }
411
412    pub fn is_legacy(&self) -> bool {
413        self.partitions.len() == 1 && self.partitions[0].is_legacy()
414    }
415
416    pub async fn load(
417        store: Arc<dyn IndexStore>,
418        frag_reuse_index: Option<Arc<FragReuseIndex>>,
419        index_cache: &LanceCache,
420    ) -> Result<Arc<Self>>
421    where
422        Self: Sized,
423    {
424        // for new index format, there is a metadata file and multiple partitions,
425        // each partition is a separate index containing tokens, inverted list and docs.
426        // for old index format, there is no metadata file, and it's just like a single partition
427
428        match store.open_index_file(METADATA_FILE).await {
429            Ok(reader) => {
430                let params = reader.schema().metadata.get("params").ok_or(Error::Index {
431                    message: "params not found in metadata".to_owned(),
432                    location: location!(),
433                })?;
434                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
435                let partitions =
436                    reader
437                        .schema()
438                        .metadata
439                        .get("partitions")
440                        .ok_or(Error::Index {
441                            message: "partitions not found in metadata".to_owned(),
442                            location: location!(),
443                        })?;
444                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
445                let token_set_format = reader
446                    .schema()
447                    .metadata
448                    .get(TOKEN_SET_FORMAT_KEY)
449                    .map(|name| TokenSetFormat::from_str(name))
450                    .transpose()?
451                    .unwrap_or(TokenSetFormat::Arrow);
452
453                let format = token_set_format;
454                let partitions = partitions.into_iter().map(|id| {
455                    let store = store.clone();
456                    let frag_reuse_index_clone = frag_reuse_index.clone();
457                    let index_cache_for_part =
458                        index_cache.with_key_prefix(format!("part-{}", id).as_str());
459                    let token_set_format = format;
460                    async move {
461                        Result::Ok(Arc::new(
462                            InvertedPartition::load(
463                                store,
464                                id,
465                                frag_reuse_index_clone,
466                                &index_cache_for_part,
467                                token_set_format,
468                            )
469                            .await?,
470                        ))
471                    }
472                });
473                let partitions = stream::iter(partitions)
474                    .buffer_unordered(store.io_parallelism())
475                    .try_collect::<Vec<_>>()
476                    .await?;
477
478                let tokenizer = params.build()?;
479                Ok(Arc::new(Self {
480                    params,
481                    store,
482                    tokenizer,
483                    token_set_format,
484                    partitions,
485                }))
486            }
487            Err(_) => {
488                // old index format
489                Self::load_legacy_index(store, frag_reuse_index, index_cache).await
490            }
491        }
492    }
493}
494
495#[async_trait]
496impl Index for InvertedIndex {
497    fn as_any(&self) -> &dyn std::any::Any {
498        self
499    }
500
501    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
502        self
503    }
504
505    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
506        Err(Error::invalid_input(
507            "inverted index cannot be cast to vector index",
508            location!(),
509        ))
510    }
511
512    fn statistics(&self) -> Result<serde_json::Value> {
513        let num_tokens = self
514            .partitions
515            .iter()
516            .map(|part| part.tokens.len())
517            .sum::<usize>();
518        let num_docs = self
519            .partitions
520            .iter()
521            .map(|part| part.docs.len())
522            .sum::<usize>();
523        Ok(serde_json::json!({
524            "params": self.params,
525            "num_tokens": num_tokens,
526            "num_docs": num_docs,
527        }))
528    }
529
530    async fn prewarm(&self) -> Result<()> {
531        for part in &self.partitions {
532            part.inverted_list.prewarm().await?;
533        }
534        Ok(())
535    }
536
537    fn index_type(&self) -> crate::IndexType {
538        crate::IndexType::Inverted
539    }
540
541    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
542        unimplemented!()
543    }
544}
545
546impl InvertedIndex {
547    /// Search docs match the input text.
548    async fn do_search(&self, text: &str) -> Result<RecordBatch> {
549        let params = FtsSearchParams::new();
550        let mut tokenizer = self.tokenizer.clone();
551        let tokens = collect_query_tokens(text, &mut tokenizer, None);
552
553        let (doc_ids, _) = self
554            .bm25_search(
555                Arc::new(tokens),
556                params.into(),
557                Operator::And,
558                Arc::new(NoFilter),
559                Arc::new(NoOpMetricsCollector),
560            )
561            .boxed()
562            .await?;
563
564        Ok(RecordBatch::try_new(
565            ROW_ID_SCHEMA.clone(),
566            vec![Arc::new(UInt64Array::from(doc_ids))],
567        )?)
568    }
569}
570
571#[async_trait]
572impl ScalarIndex for InvertedIndex {
573    // return the row ids of the documents that contain the query
574    #[instrument(level = "debug", skip_all)]
575    async fn search(
576        &self,
577        query: &dyn AnyQuery,
578        _metrics: &dyn MetricsCollector,
579    ) -> Result<SearchResult> {
580        let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();
581
582        match query {
583            TokenQuery::TokensContains(text) => {
584                let records = self.do_search(text).await?;
585                let row_ids = records
586                    .column(0)
587                    .as_any()
588                    .downcast_ref::<UInt64Array>()
589                    .unwrap();
590                let row_ids = row_ids.iter().flatten().collect_vec();
591                Ok(SearchResult::at_most(RowAddrTreeMap::from_iter(row_ids)))
592            }
593        }
594    }
595
596    fn can_remap(&self) -> bool {
597        true
598    }
599
600    async fn remap(
601        &self,
602        mapping: &HashMap<u64, Option<u64>>,
603        dest_store: &dyn IndexStore,
604    ) -> Result<CreatedIndex> {
605        self.to_builder()
606            .remap(mapping, self.store.clone(), dest_store)
607            .await?;
608
609        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
610
611        // Use version 0 for Arrow format (legacy), version 1 for Fst format (new)
612        let index_version = match self.token_set_format {
613            TokenSetFormat::Arrow => 0,
614            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
615        };
616
617        Ok(CreatedIndex {
618            index_details: prost_types::Any::from_msg(&details).unwrap(),
619            index_version,
620        })
621    }
622
623    async fn update(
624        &self,
625        new_data: SendableRecordBatchStream,
626        dest_store: &dyn IndexStore,
627    ) -> Result<CreatedIndex> {
628        self.to_builder().update(new_data, dest_store).await?;
629
630        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
631
632        // Use version 0 for Arrow format (legacy), version 1 for Fst format (new)
633        let index_version = match self.token_set_format {
634            TokenSetFormat::Arrow => 0,
635            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
636        };
637
638        Ok(CreatedIndex {
639            index_details: prost_types::Any::from_msg(&details).unwrap(),
640            index_version,
641        })
642    }
643
644    fn update_criteria(&self) -> UpdateCriteria {
645        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
646        if self.is_legacy() {
647            UpdateCriteria::requires_old_data(criteria)
648        } else {
649            UpdateCriteria::only_new_data(criteria)
650        }
651    }
652
653    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
654        let mut params = self.params.clone();
655        if params.base_tokenizer.is_empty() {
656            params.base_tokenizer = "simple".to_string();
657        }
658
659        let params_json = serde_json::to_string(&params)?;
660
661        Ok(ScalarIndexParams {
662            index_type: BuiltinIndexType::Inverted.as_str().to_string(),
663            params: Some(params_json),
664        })
665    }
666}
667
668#[derive(Debug, Clone, DeepSizeOf)]
669pub struct InvertedPartition {
670    // 0 for legacy format
671    id: u64,
672    store: Arc<dyn IndexStore>,
673    pub(crate) tokens: TokenSet,
674    pub(crate) inverted_list: Arc<PostingListReader>,
675    pub(crate) docs: DocSet,
676    token_set_format: TokenSetFormat,
677}
678
679impl InvertedPartition {
680    /// Check if this partition belongs to the specified fragment.
681    ///
682    /// This method encapsulates the bit manipulation logic for fragment filtering
683    /// in distributed indexing scenarios.
684    ///
685    /// # Arguments
686    /// * `fragment_mask` - A mask with fragment_id in high 32 bits
687    ///
688    /// # Returns
689    /// * `true` if the partition belongs to the fragment, `false` otherwise
690    pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
691        (self.id() & fragment_mask) == fragment_mask
692    }
693
694    pub fn id(&self) -> u64 {
695        self.id
696    }
697
698    pub fn store(&self) -> &dyn IndexStore {
699        self.store.as_ref()
700    }
701
702    pub fn is_legacy(&self) -> bool {
703        self.inverted_list.lengths.is_none()
704    }
705
706    pub async fn load(
707        store: Arc<dyn IndexStore>,
708        id: u64,
709        frag_reuse_index: Option<Arc<FragReuseIndex>>,
710        index_cache: &LanceCache,
711        token_set_format: TokenSetFormat,
712    ) -> Result<Self> {
713        let token_file = store.open_index_file(&token_file_path(id)).await?;
714        let tokens = TokenSet::load(token_file, token_set_format).await?;
715        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
716        let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
717        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
718        let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;
719
720        Ok(Self {
721            id,
722            store,
723            tokens,
724            inverted_list: Arc::new(inverted_list),
725            docs,
726            token_set_format,
727        })
728    }
729
730    fn map(&self, token: &str) -> Option<u32> {
731        self.tokens.get(token)
732    }
733
734    pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
735        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
736        for token in tokens {
737            let fuzziness = match params.fuzziness {
738                Some(fuzziness) => fuzziness,
739                None => MatchQuery::auto_fuzziness(token),
740            };
741            let lev =
742                fst::automaton::Levenshtein::new(token, fuzziness).map_err(|e| Error::Index {
743                    message: format!("failed to construct the fuzzy query: {}", e),
744                    location: location!(),
745                })?;
746
747            let base_len = tokens.token_type().prefix_len(token) as u32;
748            if let TokenMap::Fst(ref map) = self.tokens.tokens {
749                match base_len + params.prefix_length {
750                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
751                    prefix_length => {
752                        let prefix = &token[..min(prefix_length as usize, token.len())];
753                        let prefix = fst::automaton::Str::new(prefix).starts_with();
754                        take_fst_keys(
755                            map.search(lev.intersection(prefix)),
756                            &mut new_tokens,
757                            params.max_expansions,
758                        )
759                    }
760                }
761            } else {
762                return Err(Error::Index {
763                    message: "tokens is not fst, which is not expected".to_owned(),
764                    location: location!(),
765                });
766            }
767        }
768        Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
769    }
770
771    // search the documents that contain the query
772    // return the doc info and the doc length
773    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
774    #[instrument(level = "debug", skip_all)]
775    pub async fn load_posting_lists(
776        &self,
777        tokens: &Tokens,
778        params: &FtsSearchParams,
779        metrics: &dyn MetricsCollector,
780    ) -> Result<Vec<PostingIterator>> {
781        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
782        let is_phrase_query = params.phrase_slop.is_some();
783        let tokens = match is_fuzzy {
784            true => self.expand_fuzzy(tokens, params)?,
785            false => tokens.clone(),
786        };
787        let mut token_ids = Vec::with_capacity(tokens.len());
788        for token in tokens {
789            let token_id = self.map(&token);
790            if let Some(token_id) = token_id {
791                token_ids.push((token_id, token));
792            } else if is_phrase_query {
793                // if the token is not found, we can't do phrase query
794                return Ok(Vec::new());
795            }
796        }
797        if token_ids.is_empty() {
798            return Ok(Vec::new());
799        }
800        if !is_phrase_query {
801            // remove duplicates
802            token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
803            token_ids.dedup_by_key(|(token_id, _)| *token_id);
804        }
805
806        let num_docs = self.docs.len();
807        stream::iter(token_ids)
808            .enumerate()
809            .map(|(position, (token_id, token))| async move {
810                let posting = self
811                    .inverted_list
812                    .posting_list(token_id, is_phrase_query, metrics)
813                    .await?;
814
815                Result::Ok(PostingIterator::new(
816                    token,
817                    token_id,
818                    position as u32,
819                    posting,
820                    num_docs,
821                ))
822            })
823            .buffered(self.store.io_parallelism())
824            .try_collect::<Vec<_>>()
825            .await
826    }
827
828    #[instrument(level = "debug", skip_all)]
829    pub fn bm25_search(
830        &self,
831        params: &FtsSearchParams,
832        operator: Operator,
833        mask: Arc<RowAddrMask>,
834        postings: Vec<PostingIterator>,
835        metrics: &dyn MetricsCollector,
836    ) -> Result<Vec<DocCandidate>> {
837        if postings.is_empty() {
838            return Ok(Vec::new());
839        }
840
841        // let local_metrics = LocalMetricsCollector::default();
842        let scorer = IndexBM25Scorer::new(std::iter::once(self));
843        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
844        let hits = wand.search(params, mask, metrics)?;
845        // local_metrics.dump_into(metrics);
846        Ok(hits)
847    }
848
849    pub async fn into_builder(self) -> Result<InnerBuilder> {
850        let mut builder = InnerBuilder::new(
851            self.id,
852            self.inverted_list.has_positions(),
853            self.token_set_format,
854        );
855        builder.tokens = self.tokens;
856        builder.docs = self.docs;
857
858        builder
859            .posting_lists
860            .reserve_exact(self.inverted_list.len());
861        for posting_list in self
862            .inverted_list
863            .read_all(self.inverted_list.has_positions())
864            .await?
865        {
866            let posting_list = posting_list?;
867            builder
868                .posting_lists
869                .push(posting_list.into_builder(&builder.docs));
870        }
871        Ok(builder)
872    }
873}
874
875// at indexing, we use HashMap because we need it to be mutable,
876// at searching, we use fst::Map because it's more efficient
877#[derive(Debug, Clone)]
878pub enum TokenMap {
879    HashMap(HashMap<String, u32>),
880    Fst(fst::Map<Vec<u8>>),
881}
882
883impl Default for TokenMap {
884    fn default() -> Self {
885        Self::HashMap(HashMap::new())
886    }
887}
888
889impl DeepSizeOf for TokenMap {
890    fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
891        match self {
892            Self::HashMap(map) => map.deep_size_of_children(ctx),
893            Self::Fst(map) => map.as_fst().size(),
894        }
895    }
896}
897
898impl TokenMap {
899    pub fn len(&self) -> usize {
900        match self {
901            Self::HashMap(map) => map.len(),
902            Self::Fst(map) => map.len(),
903        }
904    }
905
906    pub fn is_empty(&self) -> bool {
907        self.len() == 0
908    }
909}
910
911// TokenSet is a mapping from tokens to token ids
912#[derive(Debug, Clone, Default, DeepSizeOf)]
913pub struct TokenSet {
914    // token -> token_id
915    pub(crate) tokens: TokenMap,
916    pub(crate) next_id: u32,
917    total_length: usize,
918}
919
920impl TokenSet {
921    pub fn into_mut(self) -> Self {
922        let tokens = match self.tokens {
923            TokenMap::HashMap(map) => map,
924            TokenMap::Fst(map) => {
925                let mut new_map = HashMap::with_capacity(map.len());
926                let mut stream = map.into_stream();
927                while let Some((token, token_id)) = stream.next() {
928                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
929                }
930
931                new_map
932            }
933        };
934
935        Self {
936            tokens: TokenMap::HashMap(tokens),
937            next_id: self.next_id,
938            total_length: self.total_length,
939        }
940    }
941
942    pub fn len(&self) -> usize {
943        self.tokens.len()
944    }
945
946    pub fn is_empty(&self) -> bool {
947        self.len() == 0
948    }
949
950    pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
951        match format {
952            TokenSetFormat::Arrow => self.into_arrow_batch(),
953            TokenSetFormat::Fst => self.into_fst_batch(),
954        }
955    }
956
957    fn into_arrow_batch(self) -> Result<RecordBatch> {
958        let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
959        let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
960
961        match self.tokens {
962            TokenMap::Fst(map) => {
963                let mut stream = map.stream();
964                while let Some((token, token_id)) = stream.next() {
965                    token_builder.append_value(String::from_utf8_lossy(token));
966                    token_id_builder.append_value(token_id as u32);
967                }
968            }
969            TokenMap::HashMap(map) => {
970                for (token, token_id) in map.into_iter().sorted_unstable() {
971                    token_builder.append_value(token);
972                    token_id_builder.append_value(token_id);
973                }
974            }
975        }
976
977        let token_col = token_builder.finish();
978        let token_id_col = token_id_builder.finish();
979
980        let schema = arrow_schema::Schema::new(vec![
981            arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
982            arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
983        ]);
984
985        let batch = RecordBatch::try_new(
986            Arc::new(schema),
987            vec![
988                Arc::new(token_col) as ArrayRef,
989                Arc::new(token_id_col) as ArrayRef,
990            ],
991        )?;
992        Ok(batch)
993    }
994
995    fn into_fst_batch(mut self) -> Result<RecordBatch> {
996        let fst_map = match std::mem::take(&mut self.tokens) {
997            TokenMap::Fst(map) => map,
998            TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
999        };
1000        let bytes = fst_map.into_fst().into_inner();
1001
1002        let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
1003        fst_builder.append_value(bytes);
1004        let fst_col = fst_builder.finish();
1005
1006        let mut next_id_builder = UInt32Builder::with_capacity(1);
1007        next_id_builder.append_value(self.next_id);
1008        let next_id_col = next_id_builder.finish();
1009
1010        let mut total_length_builder = UInt64Builder::with_capacity(1);
1011        total_length_builder.append_value(self.total_length as u64);
1012        let total_length_col = total_length_builder.finish();
1013
1014        let schema = arrow_schema::Schema::new(vec![
1015            arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
1016            arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
1017            arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
1018        ]);
1019
1020        let batch = RecordBatch::try_new(
1021            Arc::new(schema),
1022            vec![
1023                Arc::new(fst_col) as ArrayRef,
1024                Arc::new(next_id_col) as ArrayRef,
1025                Arc::new(total_length_col) as ArrayRef,
1026            ],
1027        )?;
1028        Ok(batch)
1029    }
1030
1031    fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
1032        let mut entries: Vec<_> = map.into_iter().collect();
1033        entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1034        let mut builder = fst::MapBuilder::memory();
1035        for (token, token_id) in entries {
1036            builder
1037                .insert(&token, token_id as u64)
1038                .map_err(|e| Error::Index {
1039                    message: format!("failed to insert token {}: {}", token, e),
1040                    location: location!(),
1041                })?;
1042        }
1043        Ok(builder.into_map())
1044    }
1045
1046    pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
1047        match format {
1048            TokenSetFormat::Arrow => Self::load_arrow(reader).await,
1049            TokenSetFormat::Fst => Self::load_fst(reader).await,
1050        }
1051    }
1052
1053    async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1054        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1055
1056        let (tokens, next_id, total_length) = spawn_blocking(move || {
1057            let mut next_id = 0;
1058            let mut total_length = 0;
1059            let mut tokens = fst::MapBuilder::memory();
1060
1061            let token_col = batch[TOKEN_COL].as_string::<i32>();
1062            let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1063
1064            for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1065                let token = token.ok_or(Error::Index {
1066                    message: "found null token in token set".to_owned(),
1067                    location: location!(),
1068                })?;
1069                next_id = next_id.max(token_id + 1);
1070                total_length += token.len();
1071                tokens
1072                    .insert(token, token_id as u64)
1073                    .map_err(|e| Error::Index {
1074                        message: format!("failed to insert token {}: {}", token, e),
1075                        location: location!(),
1076                    })?;
1077            }
1078
1079            Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1080        })
1081        .await
1082        .map_err(|err| Error::Execution {
1083            message: format!("failed to spawn blocking task: {}", err),
1084            location: location!(),
1085        })??;
1086
1087        Ok(Self {
1088            tokens: TokenMap::Fst(tokens),
1089            next_id,
1090            total_length,
1091        })
1092    }
1093
1094    async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1095        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1096        if batch.num_rows() == 0 {
1097            return Err(Error::Index {
1098                message: "token set batch is empty".to_owned(),
1099                location: location!(),
1100            });
1101        }
1102
1103        let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1104        let bytes = fst_col.value(0);
1105        let map = fst::Map::new(bytes.to_vec()).map_err(|e| Error::Index {
1106            message: format!("failed to load fst tokens: {}", e),
1107            location: location!(),
1108        })?;
1109
1110        let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1111        let total_length_col =
1112            batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1113
1114        let next_id = next_id_col.values().first().copied().ok_or(Error::Index {
1115            message: "token next id column is empty".to_owned(),
1116            location: location!(),
1117        })?;
1118
1119        let total_length = total_length_col
1120            .values()
1121            .first()
1122            .copied()
1123            .ok_or(Error::Index {
1124                message: "token total length column is empty".to_owned(),
1125                location: location!(),
1126            })?;
1127
1128        Ok(Self {
1129            tokens: TokenMap::Fst(map),
1130            next_id,
1131            total_length: usize::try_from(total_length).map_err(|_| Error::Index {
1132                message: format!("token total length {} overflows usize", total_length),
1133                location: location!(),
1134            })?,
1135        })
1136    }
1137
1138    pub fn add(&mut self, token: String) -> u32 {
1139        let next_id = self.next_id();
1140        let len = token.len();
1141        let token_id = match self.tokens {
1142            TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1143            _ => unreachable!("tokens must be HashMap while indexing"),
1144        };
1145
1146        // add token if it doesn't exist
1147        if token_id == next_id {
1148            self.next_id += 1;
1149            self.total_length += len;
1150        }
1151
1152        token_id
1153    }
1154
1155    pub(crate) fn get_or_add(&mut self, token: &str) -> u32 {
1156        let next_id = self.next_id;
1157        match self.tokens {
1158            TokenMap::HashMap(ref mut map) => {
1159                if let Some(&token_id) = map.get(token) {
1160                    return token_id;
1161                }
1162
1163                map.insert(token.to_owned(), next_id);
1164            }
1165            _ => unreachable!("tokens must be HashMap while indexing"),
1166        }
1167
1168        self.next_id += 1;
1169        self.total_length += token.len();
1170        next_id
1171    }
1172
1173    pub fn get(&self, token: &str) -> Option<u32> {
1174        match self.tokens {
1175            TokenMap::HashMap(ref map) => map.get(token).copied(),
1176            TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
1177        }
1178    }
1179
1180    // the `removed_token_ids` must be sorted
1181    pub fn remap(&mut self, removed_token_ids: &[u32]) {
1182        if removed_token_ids.is_empty() {
1183            return;
1184        }
1185
1186        let mut map = match std::mem::take(&mut self.tokens) {
1187            TokenMap::HashMap(map) => map,
1188            TokenMap::Fst(map) => {
1189                let mut new_map = HashMap::with_capacity(map.len());
1190                let mut stream = map.into_stream();
1191                while let Some((token, token_id)) = stream.next() {
1192                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1193                }
1194
1195                new_map
1196            }
1197        };
1198
1199        map.retain(
1200            |_, token_id| match removed_token_ids.binary_search(token_id) {
1201                Ok(_) => false,
1202                Err(index) => {
1203                    *token_id -= index as u32;
1204                    true
1205                }
1206            },
1207        );
1208
1209        self.tokens = TokenMap::HashMap(map);
1210    }
1211
1212    pub fn next_id(&self) -> u32 {
1213        self.next_id
1214    }
1215}
1216
1217pub struct PostingListReader {
1218    reader: Arc<dyn IndexReader>,
1219
1220    // legacy format only
1221    offsets: Option<Vec<usize>>,
1222
1223    // from metadata for legacy format
1224    // from column for new format
1225    max_scores: Option<Vec<f32>>,
1226
1227    // new format only
1228    lengths: Option<Vec<u32>>,
1229
1230    has_position: bool,
1231
1232    index_cache: WeakLanceCache,
1233}
1234
1235impl std::fmt::Debug for PostingListReader {
1236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1237        f.debug_struct("InvertedListReader")
1238            .field("offsets", &self.offsets)
1239            .field("max_scores", &self.max_scores)
1240            .finish()
1241    }
1242}
1243
1244impl DeepSizeOf for PostingListReader {
1245    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
1246        self.offsets.deep_size_of_children(context)
1247            + self.max_scores.deep_size_of_children(context)
1248            + self.lengths.deep_size_of_children(context)
1249    }
1250}
1251
1252impl PostingListReader {
1253    pub(crate) async fn try_new(
1254        reader: Arc<dyn IndexReader>,
1255        index_cache: &LanceCache,
1256    ) -> Result<Self> {
1257        let has_position = reader.schema().field(POSITION_COL).is_some();
1258        let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
1259            let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
1260            (Some(offsets), max_scores, None)
1261        } else {
1262            let metadata = reader
1263                .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
1264                .await?;
1265            let max_scores = metadata[MAX_SCORE_COL]
1266                .as_primitive::<Float32Type>()
1267                .values()
1268                .to_vec();
1269            let lengths = metadata[LENGTH_COL]
1270                .as_primitive::<UInt32Type>()
1271                .values()
1272                .to_vec();
1273            (None, Some(max_scores), Some(lengths))
1274        };
1275
1276        Ok(Self {
1277            reader,
1278            offsets,
1279            max_scores,
1280            lengths,
1281            has_position,
1282            index_cache: WeakLanceCache::from(index_cache),
1283        })
1284    }
1285
1286    // for legacy format
1287    // returns the offsets and max scores
1288    fn load_metadata(
1289        schema: &lance_core::datatypes::Schema,
1290    ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
1291        let offsets = schema.metadata.get("offsets").ok_or(Error::Index {
1292            message: "offsets not found in metadata".to_owned(),
1293            location: location!(),
1294        })?;
1295        let offsets = serde_json::from_str(offsets)?;
1296
1297        let max_scores = schema
1298            .metadata
1299            .get("max_scores")
1300            .map(|max_scores| serde_json::from_str(max_scores))
1301            .transpose()?;
1302        Ok((offsets, max_scores))
1303    }
1304
1305    // the number of posting lists
1306    pub fn len(&self) -> usize {
1307        match self.offsets {
1308            Some(ref offsets) => offsets.len(),
1309            None => self.reader.num_rows(),
1310        }
1311    }
1312
1313    pub fn is_empty(&self) -> bool {
1314        self.len() == 0
1315    }
1316
1317    pub(crate) fn has_positions(&self) -> bool {
1318        self.has_position
1319    }
1320
1321    pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1322        let token_id = token_id as usize;
1323
1324        match self.offsets {
1325            Some(ref offsets) => {
1326                let next_offset = offsets
1327                    .get(token_id + 1)
1328                    .copied()
1329                    .unwrap_or(self.reader.num_rows());
1330                next_offset - offsets[token_id]
1331            }
1332            None => {
1333                if let Some(lengths) = &self.lengths {
1334                    lengths[token_id] as usize
1335                } else {
1336                    panic!("posting list reader is not initialized")
1337                }
1338            }
1339        }
1340    }
1341
1342    pub(crate) async fn posting_batch(
1343        &self,
1344        token_id: u32,
1345        with_position: bool,
1346    ) -> Result<RecordBatch> {
1347        if self.offsets.is_some() {
1348            self.posting_batch_legacy(token_id, with_position).await
1349        } else {
1350            let token_id = token_id as usize;
1351            let columns = if with_position {
1352                vec![POSTING_COL, POSITION_COL]
1353            } else {
1354                vec![POSTING_COL]
1355            };
1356            let batch = self
1357                .reader
1358                .read_range(token_id..token_id + 1, Some(&columns))
1359                .await?;
1360            Ok(batch)
1361        }
1362    }
1363
1364    async fn posting_batch_legacy(
1365        &self,
1366        token_id: u32,
1367        with_position: bool,
1368    ) -> Result<RecordBatch> {
1369        let mut columns = vec![ROW_ID, FREQUENCY_COL];
1370        if with_position {
1371            columns.push(POSITION_COL);
1372        }
1373
1374        let length = self.posting_len(token_id);
1375        let token_id = token_id as usize;
1376        let offset = self.offsets.as_ref().unwrap()[token_id];
1377        let batch = self
1378            .reader
1379            .read_range(offset..offset + length, Some(&columns))
1380            .await?;
1381        Ok(batch)
1382    }
1383
1384    #[instrument(level = "debug", skip(self, metrics))]
1385    pub(crate) async fn posting_list(
1386        &self,
1387        token_id: u32,
1388        is_phrase_query: bool,
1389        metrics: &dyn MetricsCollector,
1390    ) -> Result<PostingList> {
1391        let cache_key = PostingListKey { token_id };
1392        let mut posting = self
1393            .index_cache
1394            .get_or_insert_with_key(cache_key, || async move {
1395                metrics.record_part_load();
1396                info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
1397                let batch = self.posting_batch(token_id, false).await?;
1398                self.posting_list_from_batch(&batch, token_id)
1399            })
1400            .await?
1401            .as_ref()
1402            .clone();
1403
1404        if is_phrase_query {
1405            // hit the cache and when the cache was populated, the positions column was not loaded
1406            let positions = self.read_positions(token_id).await?;
1407            posting.set_positions(positions);
1408        }
1409
1410        Ok(posting)
1411    }
1412
1413    pub(crate) fn posting_list_from_batch(
1414        &self,
1415        batch: &RecordBatch,
1416        token_id: u32,
1417    ) -> Result<PostingList> {
1418        let posting_list = PostingList::from_batch(
1419            batch,
1420            self.max_scores
1421                .as_ref()
1422                .map(|max_scores| max_scores[token_id as usize]),
1423            self.lengths
1424                .as_ref()
1425                .map(|lengths| lengths[token_id as usize]),
1426        )?;
1427        Ok(posting_list)
1428    }
1429
1430    async fn prewarm(&self) -> Result<()> {
1431        let batch = self.read_batch(false).await?;
1432        for token_id in 0..self.len() {
1433            let posting_range = self.posting_list_range(token_id as u32);
1434            let batch = batch.slice(posting_range.start, posting_range.end - posting_range.start);
1435            // Apply shrink_to_fit to create a deep copy with compacted buffers
1436            // This ensures each cached entry has its own memory, not shared references
1437            let batch = batch.shrink_to_fit()?;
1438            let posting_list = self.posting_list_from_batch(&batch, token_id as u32)?;
1439            let inserted = self
1440                .index_cache
1441                .insert_with_key(
1442                    &PostingListKey {
1443                        token_id: token_id as u32,
1444                    },
1445                    Arc::new(posting_list),
1446                )
1447                .await;
1448
1449            if !inserted {
1450                return Err(Error::Internal {
1451                    message: "Failed to prewarm index: cache is no longer available".to_string(),
1452                    location: location!(),
1453                });
1454            }
1455        }
1456
1457        Ok(())
1458    }
1459
1460    pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1461        let columns = self.posting_columns(with_position);
1462        let batch = self
1463            .reader
1464            .read_range(0..self.reader.num_rows(), Some(&columns))
1465            .await?;
1466        Ok(batch)
1467    }
1468
1469    pub(crate) async fn read_all(
1470        &self,
1471        with_position: bool,
1472    ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1473        let batch = self.read_batch(with_position).await?;
1474        Ok((0..self.len()).map(move |i| {
1475            let token_id = i as u32;
1476            let range = self.posting_list_range(token_id);
1477            let batch = batch.slice(i, range.end - range.start);
1478            self.posting_list_from_batch(&batch, token_id)
1479        }))
1480    }
1481
1482    async fn read_positions(&self, token_id: u32) -> Result<ListArray> {
1483        let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
1484            let batch = self
1485                .reader
1486                .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
1487                .await.map_err(|e| {
1488                    match e {
1489                        Error::Schema { .. } => Error::invalid_input(
1490                            "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
1491                            location!(),
1492                        ),
1493                        e => e
1494                    }
1495                })?;
1496            Result::Ok(Positions(batch[POSITION_COL]
1497                .as_list::<i32>()
1498                .clone()))
1499        }).await?;
1500        Ok(positions.0.clone())
1501    }
1502
1503    fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1504        match self.offsets {
1505            Some(ref offsets) => {
1506                let offset = offsets[token_id as usize];
1507                let posting_len = self.posting_len(token_id);
1508                offset..offset + posting_len
1509            }
1510            None => {
1511                let token_id = token_id as usize;
1512                token_id..token_id + 1
1513            }
1514        }
1515    }
1516
1517    fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1518        let mut base_columns = match self.offsets {
1519            Some(_) => vec![ROW_ID, FREQUENCY_COL],
1520            None => vec![POSTING_COL],
1521        };
1522        if with_position {
1523            base_columns.push(POSITION_COL);
1524        }
1525        base_columns
1526    }
1527}
1528
1529/// New type just to allow Positions implement DeepSizeOf so it can be put
1530/// in the cache.
1531#[derive(Clone)]
1532pub struct Positions(ListArray);
1533
1534impl DeepSizeOf for Positions {
1535    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1536        self.0.get_buffer_memory_size()
1537    }
1538}
1539
// Typed keys for the index cache, one per kind of cached value.

/// Cache key for the decoded posting list of a single token.
#[derive(Debug, Clone)]
pub struct PostingListKey {
    /// Token whose posting list is cached.
    pub token_id: u32,
}
1545
1546impl CacheKey for PostingListKey {
1547    type ValueType = PostingList;
1548
1549    fn key(&self) -> std::borrow::Cow<'_, str> {
1550        format!("postings-{}", self.token_id).into()
1551    }
1552}
1553
/// Cache key for the positions array of a single token.
#[derive(Debug, Clone)]
pub struct PositionKey {
    /// Token whose positions are cached.
    pub token_id: u32,
}
1558
1559impl CacheKey for PositionKey {
1560    type ValueType = Positions;
1561
1562    fn key(&self) -> std::borrow::Cow<'_, str> {
1563        format!("positions-{}", self.token_id).into()
1564    }
1565}
1566
1567#[derive(Debug, Clone, DeepSizeOf)]
1568pub enum PostingList {
1569    Plain(PlainPostingList),
1570    Compressed(CompressedPostingList),
1571}
1572
1573impl PostingList {
1574    pub fn from_batch(
1575        batch: &RecordBatch,
1576        max_score: Option<f32>,
1577        length: Option<u32>,
1578    ) -> Result<Self> {
1579        match batch.column_by_name(POSTING_COL) {
1580            Some(_) => {
1581                debug_assert!(max_score.is_some() && length.is_some());
1582                let posting =
1583                    CompressedPostingList::from_batch(batch, max_score.unwrap(), length.unwrap());
1584                Ok(Self::Compressed(posting))
1585            }
1586            None => {
1587                let posting = PlainPostingList::from_batch(batch, max_score);
1588                Ok(Self::Plain(posting))
1589            }
1590        }
1591    }
1592
1593    pub fn iter(&self) -> PostingListIterator<'_> {
1594        PostingListIterator::new(self)
1595    }
1596
1597    pub fn has_position(&self) -> bool {
1598        match self {
1599            Self::Plain(posting) => posting.positions.is_some(),
1600            Self::Compressed(posting) => posting.positions.is_some(),
1601        }
1602    }
1603
1604    pub fn set_positions(&mut self, positions: ListArray) {
1605        match self {
1606            Self::Plain(posting) => posting.positions = Some(positions),
1607            Self::Compressed(posting) => {
1608                posting.positions = Some(positions.value(0).as_list::<i32>().clone());
1609            }
1610        }
1611    }
1612
1613    pub fn max_score(&self) -> Option<f32> {
1614        match self {
1615            Self::Plain(posting) => posting.max_score,
1616            Self::Compressed(posting) => Some(posting.max_score),
1617        }
1618    }
1619
1620    pub fn len(&self) -> usize {
1621        match self {
1622            Self::Plain(posting) => posting.len(),
1623            Self::Compressed(posting) => posting.length as usize,
1624        }
1625    }
1626
1627    pub fn is_empty(&self) -> bool {
1628        self.len() == 0
1629    }
1630
1631    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
1632        let mut builder = PostingListBuilder::new(self.has_position());
1633        match self {
1634            // legacy format
1635            Self::Plain(posting) => {
1636                // convert the posting list to the new format:
1637                // 1. map row ids to doc ids
1638                // 2. sort the posting list by doc ids
1639                struct Item {
1640                    doc_id: u32,
1641                    positions: PositionRecorder,
1642                }
1643                let doc_ids = docs
1644                    .row_ids
1645                    .iter()
1646                    .enumerate()
1647                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
1648                    .collect::<HashMap<_, _>>();
1649                let mut items = Vec::with_capacity(posting.len());
1650                for (row_id, freq, positions) in posting.iter() {
1651                    let freq = freq as u32;
1652                    let positions = match positions {
1653                        Some(positions) => {
1654                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
1655                        }
1656                        None => PositionRecorder::Count(freq),
1657                    };
1658                    items.push(Item {
1659                        doc_id: doc_ids[&row_id],
1660                        positions,
1661                    });
1662                }
1663                items.sort_unstable_by_key(|item| item.doc_id);
1664                for item in items {
1665                    builder.add(item.doc_id, item.positions);
1666                }
1667            }
1668            Self::Compressed(posting) => {
1669                posting.iter().for_each(|(doc_id, freq, positions)| {
1670                    let positions = match positions {
1671                        Some(positions) => {
1672                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
1673                        }
1674                        None => PositionRecorder::Count(freq),
1675                    };
1676                    builder.add(doc_id, positions);
1677                });
1678            }
1679        }
1680        builder
1681    }
1682}
1683
1684#[derive(Debug, PartialEq, Clone)]
1685pub struct PlainPostingList {
1686    pub row_ids: ScalarBuffer<u64>,
1687    pub frequencies: ScalarBuffer<f32>,
1688    pub max_score: Option<f32>,
1689    pub positions: Option<ListArray>, // List of Int32
1690}
1691
1692impl DeepSizeOf for PlainPostingList {
1693    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1694        self.row_ids.len() * std::mem::size_of::<u64>()
1695            + self.frequencies.len() * std::mem::size_of::<u32>()
1696            + self
1697                .positions
1698                .as_ref()
1699                .map(|positions| positions.get_buffer_memory_size())
1700                .unwrap_or(0)
1701    }
1702}
1703
1704impl PlainPostingList {
1705    pub fn new(
1706        row_ids: ScalarBuffer<u64>,
1707        frequencies: ScalarBuffer<f32>,
1708        max_score: Option<f32>,
1709        positions: Option<ListArray>,
1710    ) -> Self {
1711        Self {
1712            row_ids,
1713            frequencies,
1714            max_score,
1715            positions,
1716        }
1717    }
1718
1719    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
1720        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
1721        let frequencies = batch[FREQUENCY_COL]
1722            .as_primitive::<Float32Type>()
1723            .values()
1724            .clone();
1725        let positions = batch
1726            .column_by_name(POSITION_COL)
1727            .map(|col| col.as_list::<i32>().clone());
1728
1729        Self::new(row_ids, frequencies, max_score, positions)
1730    }
1731
1732    pub fn len(&self) -> usize {
1733        self.row_ids.len()
1734    }
1735
1736    pub fn is_empty(&self) -> bool {
1737        self.len() == 0
1738    }
1739
1740    pub fn iter(&self) -> PlainPostingListIterator<'_> {
1741        Box::new(
1742            self.row_ids
1743                .iter()
1744                .zip(self.frequencies.iter())
1745                .enumerate()
1746                .map(|(idx, (doc_id, freq))| {
1747                    (
1748                        *doc_id,
1749                        *freq,
1750                        self.positions.as_ref().map(|p| {
1751                            let start = p.value_offsets()[idx] as usize;
1752                            let end = p.value_offsets()[idx + 1] as usize;
1753                            Box::new(
1754                                p.values().as_primitive::<Int32Type>().values()[start..end]
1755                                    .iter()
1756                                    .map(|pos| *pos as u32),
1757                            ) as _
1758                        }),
1759                    )
1760                }),
1761        )
1762    }
1763
1764    #[inline]
1765    pub fn doc(&self, i: usize) -> LocatedDocInfo {
1766        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
1767    }
1768
1769    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
1770        self.positions
1771            .as_ref()
1772            .map(|positions| positions.value(index))
1773    }
1774
1775    pub fn max_score(&self) -> Option<f32> {
1776        self.max_score
1777    }
1778
1779    pub fn row_id(&self, i: usize) -> u64 {
1780        self.row_ids[i]
1781    }
1782}
1783
/// A posting list stored as a sequence of compressed blocks.
#[derive(Debug, PartialEq, Clone)]
pub struct CompressedPostingList {
    pub max_score: f32,
    // number of entries; passed to the block iterator as the total length
    pub length: u32,
    // each binary is a block of compressed data
    // that contains `BLOCK_SIZE` doc ids and then `BLOCK_SIZE` frequencies;
    // bytes 0..4 of a block are read as a little-endian f32 (block max score)
    // and bytes 4..8 as a little-endian u32 (block's least doc id)
    pub blocks: LargeBinaryArray,
    pub positions: Option<ListArray>,
}
1793
1794impl DeepSizeOf for CompressedPostingList {
1795    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1796        self.blocks.get_buffer_memory_size()
1797            + self
1798                .positions
1799                .as_ref()
1800                .map(|positions| positions.get_buffer_memory_size())
1801                .unwrap_or(0)
1802    }
1803}
1804
1805impl CompressedPostingList {
1806    pub fn new(
1807        blocks: LargeBinaryArray,
1808        max_score: f32,
1809        length: u32,
1810        positions: Option<ListArray>,
1811    ) -> Self {
1812        Self {
1813            max_score,
1814            length,
1815            blocks,
1816            positions,
1817        }
1818    }
1819
1820    pub fn from_batch(batch: &RecordBatch, max_score: f32, length: u32) -> Self {
1821        debug_assert_eq!(batch.num_rows(), 1);
1822        let blocks = batch[POSTING_COL]
1823            .as_list::<i32>()
1824            .value(0)
1825            .as_binary::<i64>()
1826            .clone();
1827        let positions = batch
1828            .column_by_name(POSITION_COL)
1829            .map(|col| col.as_list::<i32>().value(0).as_list::<i32>().clone());
1830
1831        Self {
1832            max_score,
1833            length,
1834            blocks,
1835            positions,
1836        }
1837    }
1838
1839    pub fn iter(&self) -> CompressedPostingListIterator {
1840        CompressedPostingListIterator::new(
1841            self.length as usize,
1842            self.blocks.clone(),
1843            self.positions.clone(),
1844        )
1845    }
1846
1847    pub fn block_max_score(&self, block_idx: usize) -> f32 {
1848        let block = self.blocks.value(block_idx);
1849        block[0..4].try_into().map(f32::from_le_bytes).unwrap()
1850    }
1851
1852    pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
1853        let block = self.blocks.value(block_idx);
1854        block[4..8].try_into().map(u32::from_le_bytes).unwrap()
1855    }
1856}
1857
/// Accumulates postings (doc id, frequency, optional positions) that are
/// later encoded into a record batch by `to_batch`.
#[derive(Debug)]
pub struct PostingListBuilder {
    // doc ids in insertion order (`to_batch` assumes they are sorted)
    pub doc_ids: ExpLinkedList<u32>,
    // per-document frequency (the recorder's `len()`), parallel to `doc_ids`
    pub frequencies: ExpLinkedList<u32>,
    // per-document positions; `Some` only when positions are recorded
    pub positions: Option<PositionBuilder>,
}
1864
1865impl PostingListBuilder {
1866    pub fn size(&self) -> u64 {
1867        (std::mem::size_of::<u32>() * self.doc_ids.len()
1868            + std::mem::size_of::<u32>() * self.frequencies.len()
1869            + self
1870                .positions
1871                .as_ref()
1872                .map(|positions| positions.size())
1873                .unwrap_or(0)) as u64
1874    }
1875
1876    pub fn has_positions(&self) -> bool {
1877        self.positions.is_some()
1878    }
1879
1880    pub fn new(with_position: bool) -> Self {
1881        Self {
1882            doc_ids: ExpLinkedList::new().with_capacity_limit(128),
1883            frequencies: ExpLinkedList::new().with_capacity_limit(128),
1884            positions: with_position.then(PositionBuilder::new),
1885        }
1886    }
1887
1888    pub fn len(&self) -> usize {
1889        self.doc_ids.len()
1890    }
1891
1892    pub fn is_empty(&self) -> bool {
1893        self.len() == 0
1894    }
1895
1896    pub fn iter(&self) -> impl Iterator<Item = (&u32, &u32, Option<&[u32]>)> {
1897        self.doc_ids
1898            .iter()
1899            .zip(self.frequencies.iter())
1900            .enumerate()
1901            .map(|(idx, (doc_id, freq))| {
1902                let positions = self.positions.as_ref().map(|positions| positions.get(idx));
1903                (doc_id, freq, positions)
1904            })
1905    }
1906
1907    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
1908        self.doc_ids.push(doc_id);
1909        self.frequencies.push(term_positions.len());
1910        if let Some(positions) = self.positions.as_mut() {
1911            positions.push(term_positions.into_vec());
1912        }
1913    }
1914
1915    // assume the posting list is sorted by doc id
1916    pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
1917        let length = self.len();
1918        let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);
1919
1920        let schema = inverted_list_schema(self.has_positions());
1921        let compressed = compress_posting_list(
1922            self.doc_ids.len(),
1923            self.doc_ids.iter(),
1924            self.frequencies.iter(),
1925            block_max_scores.into_iter(),
1926        )?;
1927        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
1928        let mut columns = vec![
1929            Arc::new(ListArray::try_new(
1930                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
1931                offsets,
1932                Arc::new(compressed),
1933                None,
1934            )?) as ArrayRef,
1935            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
1936            Arc::new(UInt32Array::from_iter_values(std::iter::once(
1937                self.len() as u32
1938            ))) as ArrayRef,
1939        ];
1940
1941        if let Some(positions) = self.positions.as_ref() {
1942            let mut position_builder = ListBuilder::new(ListBuilder::with_capacity(
1943                LargeBinaryBuilder::new(),
1944                length,
1945            ));
1946            for index in 0..length {
1947                let positions_in_doc = positions.get(index);
1948                let compressed = compress_positions(positions_in_doc)?;
1949                let inner_builder = position_builder.values();
1950                inner_builder.append_value(compressed.into_iter());
1951            }
1952            position_builder.append(true);
1953            let position_col = position_builder.finish();
1954            columns.push(Arc::new(position_col));
1955        }
1956
1957        let batch = RecordBatch::try_new(schema, columns)?;
1958        Ok(batch)
1959    }
1960
1961    pub fn remap(&mut self, removed: &[u32]) {
1962        let mut cursor = 0;
1963        let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
1964        let mut new_frequencies = ExpLinkedList::with_capacity(self.len());
1965        let mut new_positions = self.positions.as_mut().map(|_| PositionBuilder::new());
1966        for (&doc_id, &freq, positions) in self.iter() {
1967            while cursor < removed.len() && removed[cursor] < doc_id {
1968                cursor += 1;
1969            }
1970            if cursor < removed.len() && removed[cursor] == doc_id {
1971                // this doc is removed
1972                continue;
1973            }
1974            // there are cursor removed docs before this doc
1975            // so we need to shift the doc id
1976            new_doc_ids.push(doc_id - cursor as u32);
1977            new_frequencies.push(freq);
1978            if let Some(new_positions) = new_positions.as_mut() {
1979                new_positions.push(positions.unwrap().to_vec());
1980            }
1981        }
1982
1983        self.doc_ids = new_doc_ids;
1984        self.frequencies = new_frequencies;
1985        self.positions = new_positions;
1986    }
1987}
1988
/// Flattened storage for per-document position lists.
///
/// The positions of document `i` are `positions[offsets[i]..offsets[i + 1]]`.
/// `offsets` starts as `[0]` and gains one entry per pushed document.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct PositionBuilder {
    positions: Vec<u32>,
    offsets: Vec<i32>,
}
1994
1995impl Default for PositionBuilder {
1996    fn default() -> Self {
1997        Self::new()
1998    }
1999}
2000
2001impl PositionBuilder {
2002    pub fn new() -> Self {
2003        Self {
2004            positions: Vec::new(),
2005            offsets: vec![0],
2006        }
2007    }
2008
2009    pub fn size(&self) -> usize {
2010        std::mem::size_of::<u32>() * self.positions.len()
2011            + std::mem::size_of::<i32>() * self.offsets.len()
2012    }
2013
2014    pub fn total_len(&self) -> usize {
2015        self.positions.len()
2016    }
2017
2018    pub fn push(&mut self, positions: Vec<u32>) {
2019        self.positions.extend(positions);
2020        self.offsets.push(self.positions.len() as i32);
2021    }
2022
2023    pub fn get(&self, i: usize) -> &[u32] {
2024        let start = self.offsets[i] as usize;
2025        let end = self.offsets[i + 1] as usize;
2026        &self.positions[start..end]
2027    }
2028}
2029
2030impl From<Vec<Vec<u32>>> for PositionBuilder {
2031    fn from(positions: Vec<Vec<u32>>) -> Self {
2032        let mut builder = Self::new();
2033        builder.offsets.reserve(positions.len());
2034        for pos in positions {
2035            builder.push(pos);
2036        }
2037        builder
2038    }
2039}
2040
/// A posting entry, identified either by row id (`Located`, with an `f32`
/// frequency) or by doc id (`Raw`, with a `u32` frequency).
///
/// Equality and ordering of `DocInfo` compare only the id returned by `doc_id()`.
#[derive(Debug, Clone, DeepSizeOf, Copy)]
pub enum DocInfo {
    Located(LocatedDocInfo),
    Raw(RawDocInfo),
}
2046
2047impl DocInfo {
2048    pub fn doc_id(&self) -> u64 {
2049        match self {
2050            Self::Raw(info) => info.doc_id as u64,
2051            Self::Located(info) => info.row_id,
2052        }
2053    }
2054
2055    pub fn frequency(&self) -> u32 {
2056        match self {
2057            Self::Raw(info) => info.frequency,
2058            Self::Located(info) => info.frequency as u32,
2059        }
2060    }
2061}
2062
2063impl Eq for DocInfo {}
2064
2065impl PartialEq for DocInfo {
2066    fn eq(&self, other: &Self) -> bool {
2067        self.doc_id() == other.doc_id()
2068    }
2069}
2070
2071impl PartialOrd for DocInfo {
2072    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2073        Some(self.cmp(other))
2074    }
2075}
2076
2077impl Ord for DocInfo {
2078    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2079        self.doc_id().cmp(&other.doc_id())
2080    }
2081}
2082
/// A posting entry addressed by row id, with an `f32` term frequency.
///
/// Equality and ordering consider only `row_id`.
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct LocatedDocInfo {
    pub row_id: u64,
    pub frequency: f32,
}
2088
2089impl LocatedDocInfo {
2090    pub fn new(row_id: u64, frequency: f32) -> Self {
2091        Self { row_id, frequency }
2092    }
2093}
2094
2095impl Eq for LocatedDocInfo {}
2096
2097impl PartialEq for LocatedDocInfo {
2098    fn eq(&self, other: &Self) -> bool {
2099        self.row_id == other.row_id
2100    }
2101}
2102
2103impl PartialOrd for LocatedDocInfo {
2104    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2105        Some(self.cmp(other))
2106    }
2107}
2108
2109impl Ord for LocatedDocInfo {
2110    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2111        self.row_id.cmp(&other.row_id)
2112    }
2113}
2114
/// A posting entry addressed by doc id, with a `u32` term frequency.
///
/// Equality and ordering consider only `doc_id`.
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct RawDocInfo {
    pub doc_id: u32,
    pub frequency: u32,
}
2120
2121impl RawDocInfo {
2122    pub fn new(doc_id: u32, frequency: u32) -> Self {
2123        Self { doc_id, frequency }
2124    }
2125}
2126
2127impl Eq for RawDocInfo {}
2128
2129impl PartialEq for RawDocInfo {
2130    fn eq(&self, other: &Self) -> bool {
2131        self.doc_id == other.doc_id
2132    }
2133}
2134
2135impl PartialOrd for RawDocInfo {
2136    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2137        Some(self.cmp(other))
2138    }
2139}
2140
2141impl Ord for RawDocInfo {
2142    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2143        self.doc_id.cmp(&other.doc_id)
2144    }
2145}
2146
// DocSet maps doc ids (indexes into `row_ids` / `num_tokens`) to row ids and
// per-document token counts; it supplies the document lengths used in BM25 scoring.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct DocSet {
    // row id of each document, indexed by doc id
    row_ids: Vec<u64>,
    // number of tokens in each document, indexed by doc id
    num_tokens: Vec<u32>,
    // (row_id, doc_id) pairs sorted by row_id;
    // left empty for the legacy format, where the row id is the doc id
    inv: Vec<(u64, u32)>,

    // running total of tokens across documents; used by `average_length`
    total_tokens: u64,
}
2158
2159impl DocSet {
2160    #[inline]
2161    pub fn len(&self) -> usize {
2162        self.row_ids.len()
2163    }
2164
2165    pub fn is_empty(&self) -> bool {
2166        self.len() == 0
2167    }
2168
2169    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
2170        self.row_ids.iter().zip(self.num_tokens.iter())
2171    }
2172
2173    pub fn row_id(&self, doc_id: u32) -> u64 {
2174        self.row_ids[doc_id as usize]
2175    }
2176
2177    pub fn doc_id(&self, row_id: u64) -> Option<u64> {
2178        if self.inv.is_empty() {
2179            // in legacy format, the row id is doc id
2180            match self.row_ids.binary_search(&row_id) {
2181                Ok(_) => Some(row_id),
2182                Err(_) => None,
2183            }
2184        } else {
2185            match self.inv.binary_search_by_key(&row_id, |x| x.0) {
2186                Ok(idx) => Some(self.inv[idx].1 as u64),
2187                Err(_) => None,
2188            }
2189        }
2190    }
2191    pub fn total_tokens_num(&self) -> u64 {
2192        self.total_tokens
2193    }
2194
2195    #[inline]
2196    pub fn average_length(&self) -> f32 {
2197        self.total_tokens as f32 / self.len() as f32
2198    }
2199
2200    pub fn calculate_block_max_scores<'a>(
2201        &self,
2202        doc_ids: impl Iterator<Item = &'a u32>,
2203        freqs: impl Iterator<Item = &'a u32>,
2204    ) -> Vec<f32> {
2205        let avgdl = self.average_length();
2206        let length = doc_ids.size_hint().0;
2207        let num_blocks = length.div_ceil(BLOCK_SIZE);
2208        let mut block_max_scores = Vec::with_capacity(num_blocks);
2209        let idf_scale = idf(length, self.len()) * (K1 + 1.0);
2210        let mut max_score = f32::MIN;
2211        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
2212            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
2213            let freq = *freq as f32;
2214            let score = freq / (freq + doc_norm);
2215            if score > max_score {
2216                max_score = score;
2217            }
2218            if (i + 1) % BLOCK_SIZE == 0 {
2219                max_score *= idf_scale;
2220                block_max_scores.push(max_score);
2221                max_score = f32::MIN;
2222            }
2223        }
2224        if length % BLOCK_SIZE > 0 {
2225            max_score *= idf_scale;
2226            block_max_scores.push(max_score);
2227        }
2228        block_max_scores
2229    }
2230
2231    pub fn to_batch(&self) -> Result<RecordBatch> {
2232        let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
2233        let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
2234
2235        let schema = arrow_schema::Schema::new(vec![
2236            arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
2237            arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
2238        ]);
2239
2240        let batch = RecordBatch::try_new(
2241            Arc::new(schema),
2242            vec![
2243                Arc::new(row_id_col) as ArrayRef,
2244                Arc::new(num_tokens_col) as ArrayRef,
2245            ],
2246        )?;
2247        Ok(batch)
2248    }
2249
2250    pub async fn load(
2251        reader: Arc<dyn IndexReader>,
2252        is_legacy: bool,
2253        frag_reuse_index: Option<Arc<FragReuseIndex>>,
2254    ) -> Result<Self> {
2255        let batch = reader.read_range(0..reader.num_rows(), None).await?;
2256        let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
2257        let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
2258
2259        // for legacy format, the row id is doc id; sorting keeps binary search viable
2260        if is_legacy {
2261            let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
2262                .values()
2263                .iter()
2264                .filter_map(|id| {
2265                    if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2266                        frag_reuse_index_ref.remap_row_id(*id)
2267                    } else {
2268                        Some(*id)
2269                    }
2270                })
2271                .zip(num_tokens_col.values().iter())
2272                .sorted_unstable_by_key(|x| x.0)
2273                .unzip();
2274
2275            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2276            return Ok(Self {
2277                row_ids,
2278                num_tokens,
2279                inv: Vec::new(),
2280                total_tokens,
2281            });
2282        }
2283
2284        // if frag reuse happened, we'll need to remap the row_ids. And after row_ids been
2285        // remapped, we'll need resort to make sure binary_search works.
2286        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2287            let mut row_ids = Vec::with_capacity(row_id_col.len());
2288            let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
2289            for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
2290                if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
2291                    row_ids.push(new_row_id);
2292                    num_tokens.push(*num_token);
2293                }
2294            }
2295
2296            let mut inv: Vec<(u64, u32)> = row_ids
2297                .iter()
2298                .enumerate()
2299                .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2300                .collect();
2301            inv.sort_unstable_by_key(|entry| entry.0);
2302
2303            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2304            return Ok(Self {
2305                row_ids,
2306                num_tokens,
2307                inv,
2308                total_tokens,
2309            });
2310        }
2311
2312        let row_ids = row_id_col.values().to_vec();
2313        let num_tokens = num_tokens_col.values().to_vec();
2314        let mut inv: Vec<(u64, u32)> = row_ids
2315            .iter()
2316            .enumerate()
2317            .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2318            .collect();
2319        if !row_ids.is_sorted() {
2320            inv.sort_unstable_by_key(|entry| entry.0);
2321        }
2322        let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2323        Ok(Self {
2324            row_ids,
2325            num_tokens,
2326            inv,
2327            total_tokens,
2328        })
2329    }
2330
2331    // remap the row ids to the new row ids
2332    // returns the removed doc ids
2333    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
2334        let mut removed = Vec::new();
2335        let len = self.len();
2336        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
2337        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
2338        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
2339            match mapping.get(&row_id) {
2340                Some(Some(new_row_id)) => {
2341                    self.row_ids.push(*new_row_id);
2342                    self.num_tokens.push(num_token);
2343                }
2344                Some(None) => {
2345                    removed.push(doc_id as u32);
2346                }
2347                None => {
2348                    self.row_ids.push(row_id);
2349                    self.num_tokens.push(num_token);
2350                }
2351            }
2352        }
2353        removed
2354    }
2355
2356    #[inline]
2357    pub fn num_tokens(&self, doc_id: u32) -> u32 {
2358        self.num_tokens[doc_id as usize]
2359    }
2360
2361    // this can be used only if it's a legacy format,
2362    // which store the sorted row ids so that we can use binary search
2363    #[inline]
2364    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
2365        self.row_ids
2366            .binary_search(&row_id)
2367            .map(|idx| self.num_tokens[idx])
2368            .unwrap_or(0)
2369    }
2370
2371    // append a document to the doc set
2372    // returns the doc_id (the number of documents before appending)
2373    pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
2374        self.row_ids.push(row_id);
2375        self.num_tokens.push(num_tokens);
2376        self.total_tokens += num_tokens as u64;
2377        self.row_ids.len() as u32 - 1
2378    }
2379}
2380
2381pub fn flat_full_text_search(
2382    batches: &[&RecordBatch],
2383    doc_col: &str,
2384    query: &str,
2385    tokenizer: Option<Box<dyn LanceTokenizer>>,
2386) -> Result<Vec<u64>> {
2387    if batches.is_empty() {
2388        return Ok(vec![]);
2389    }
2390
2391    if is_phrase_query(query) {
2392        return Err(Error::invalid_input(
2393            "phrase query is not supported for flat full text search, try using FTS index",
2394            location!(),
2395        ));
2396    }
2397
2398    match batches[0][doc_col].data_type() {
2399        DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
2400        DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
2401        data_type => Err(Error::invalid_input(
2402            format!("unsupported data type {} for inverted index", data_type),
2403            location!(),
2404        )),
2405    }
2406}
2407
2408fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
2409    batches: &[&RecordBatch],
2410    doc_col: &str,
2411    query: &str,
2412    tokenizer: Option<Box<dyn LanceTokenizer>>,
2413) -> Result<Vec<u64>> {
2414    let mut results = Vec::new();
2415    let mut tokenizer =
2416        tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
2417    let query_tokens = collect_query_tokens(query, &mut tokenizer, None);
2418
2419    for batch in batches {
2420        let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
2421        let doc_array = batch[doc_col].as_string::<Offset>();
2422        for i in 0..row_id_array.len() {
2423            let doc = doc_array.value(i);
2424            let doc_tokens = collect_doc_tokens(doc, &mut tokenizer, Some(&query_tokens));
2425            if !doc_tokens.is_empty() {
2426                results.push(row_id_array.value(i));
2427                assert!(doc.contains(query));
2428            }
2429        }
2430    }
2431
2432    Ok(results)
2433}
2434
2435#[allow(clippy::too_many_arguments)]
2436pub fn flat_bm25_search(
2437    batch: RecordBatch,
2438    doc_col: &str,
2439    query_tokens: &Tokens,
2440    tokenizer: &mut Box<dyn LanceTokenizer>,
2441    scorer: &mut MemBM25Scorer,
2442    schema: SchemaRef,
2443) -> std::result::Result<RecordBatch, DataFusionError> {
2444    let doc_iter = iter_str_array(&batch[doc_col]);
2445    let mut scores = Vec::with_capacity(batch.num_rows());
2446    for doc in doc_iter {
2447        let Some(doc) = doc else {
2448            scores.push(0.0);
2449            continue;
2450        };
2451
2452        let doc_tokens = collect_doc_tokens(doc, tokenizer, None);
2453        scorer.update(&doc_tokens);
2454        let doc_tokens = doc_tokens
2455            .into_iter()
2456            .filter(|t| query_tokens.contains(t))
2457            .collect::<Vec<_>>();
2458
2459        let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / scorer.avg_doc_length());
2460        let mut doc_token_count = HashMap::new();
2461        for token in doc_tokens {
2462            doc_token_count
2463                .entry(token)
2464                .and_modify(|count| *count += 1)
2465                .or_insert(1);
2466        }
2467        let mut score = 0.0;
2468        for token in query_tokens {
2469            let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;
2470
2471            let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
2472            score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
2473        }
2474        scores.push(score);
2475    }
2476
2477    let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
2478    let batch = batch
2479        .try_with_column(SCORE_FIELD.clone(), score_col)?
2480        .project_by_schema(&schema)?;
2481    Ok(batch)
2482}
2483
2484pub fn flat_bm25_search_stream(
2485    input: SendableRecordBatchStream,
2486    doc_col: String,
2487    query: String,
2488    index: &Option<InvertedIndex>,
2489    schema: SchemaRef,
2490) -> SendableRecordBatchStream {
2491    let mut tokenizer = match index {
2492        Some(index) => index.tokenizer(),
2493        None => Box::new(TextTokenizer::new(
2494            tantivy::tokenizer::TextAnalyzer::builder(
2495                tantivy::tokenizer::SimpleTokenizer::default(),
2496            )
2497            .build(),
2498        )),
2499    };
2500    let tokens = collect_query_tokens(&query, &mut tokenizer, None);
2501
2502    let mut bm25_scorer = match index {
2503        Some(index) => {
2504            let index_bm25_scorer =
2505                IndexBM25Scorer::new(index.partitions.iter().map(|p| p.as_ref()));
2506            if index_bm25_scorer.num_docs() == 0 {
2507                MemBM25Scorer::new(0, 0, HashMap::new())
2508            } else {
2509                let mut token_docs = HashMap::with_capacity(tokens.len());
2510                for token in &tokens {
2511                    let token_nq = index_bm25_scorer.num_docs_containing_token(token).max(1);
2512                    token_docs.insert(token.clone(), token_nq);
2513                }
2514                MemBM25Scorer::new(
2515                    index_bm25_scorer.avg_doc_length() as u64 * index_bm25_scorer.num_docs() as u64,
2516                    index_bm25_scorer.num_docs(),
2517                    token_docs,
2518                )
2519            }
2520        }
2521        None => MemBM25Scorer::new(0, 0, HashMap::new()),
2522    };
2523
2524    let batch_schema = schema.clone();
2525    let stream = input.map(move |batch| {
2526        let batch = batch?;
2527
2528        let batch = flat_bm25_search(
2529            batch,
2530            &doc_col,
2531            &tokens,
2532            &mut tokenizer,
2533            &mut bm25_scorer,
2534            batch_schema.clone(),
2535        )?;
2536
2537        // filter out rows with score 0
2538        let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
2539        let mask = score_col
2540            .iter()
2541            .map(|score| score.is_some_and(|score| score > 0.0))
2542            .collect::<Vec<_>>();
2543        let mask = BooleanArray::from(mask);
2544        let batch = arrow::compute::filter_record_batch(&batch, &mask)?;
2545        debug_assert!(batch[ROW_ID].null_count() == 0, "flat FTS produces nulls");
2546        Ok(batch)
2547    });
2548
2549    Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream
2550}
2551
/// Returns `true` when `query` is wrapped in a pair of double quotes, e.g. `"hello world"`.
///
/// A lone `"` both starts and ends with a quote, but it is not an
/// opening/closing pair, so it is not treated as a phrase query.
pub fn is_phrase_query(query: &str) -> bool {
    query.len() >= 2 && query.starts_with('"') && query.ends_with('"')
}
2555
2556#[cfg(test)]
2557mod tests {
2558    use crate::scalar::inverted::lance_tokenizer::DocType;
2559    use lance_core::cache::LanceCache;
2560    use lance_core::utils::tempfile::TempObjDir;
2561    use lance_io::object_store::ObjectStore;
2562
2563    use crate::metrics::NoOpMetricsCollector;
2564    use crate::prefilter::NoFilter;
2565    use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
2566    use crate::scalar::inverted::encoding::decompress_posting_list;
2567    use crate::scalar::inverted::query::{FtsSearchParams, Operator};
2568    use crate::scalar::lance_format::LanceIndexStore;
2569
2570    use super::*;
2571
2572    #[tokio::test]
2573    async fn test_posting_builder_remap() {
2574        let mut builder = PostingListBuilder::new(false);
2575        let n = BLOCK_SIZE + 3;
2576        for i in 0..n {
2577            builder.add(i as u32, PositionRecorder::Count(1));
2578        }
2579        let removed = vec![5, 7];
2580        builder.remap(&removed);
2581
2582        let mut expected = PostingListBuilder::new(false);
2583        for i in 0..n - removed.len() {
2584            expected.add(i as u32, PositionRecorder::Count(1));
2585        }
2586        assert_eq!(builder.doc_ids, expected.doc_ids);
2587        assert_eq!(builder.frequencies, expected.frequencies);
2588
2589        // BLOCK_SIZE + 3 elements should be reduced to BLOCK_SIZE + 1,
2590        // there are still 2 blocks.
2591        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
2592        let (doc_ids, freqs) = decompress_posting_list(
2593            (n - removed.len()) as u32,
2594            batch[POSTING_COL]
2595                .as_list::<i32>()
2596                .value(0)
2597                .as_binary::<i64>(),
2598        )
2599        .unwrap();
2600        assert!(doc_ids
2601            .iter()
2602            .zip(expected.doc_ids.iter())
2603            .all(|(a, b)| a == b));
2604        assert!(freqs
2605            .iter()
2606            .zip(expected.frequencies.iter())
2607            .all(|(a, b)| a == b));
2608    }
2609
2610    #[tokio::test]
2611    async fn test_remap_to_empty_posting_list() {
2612        let tmpdir = TempObjDir::default();
2613        let store = Arc::new(LanceIndexStore::new(
2614            ObjectStore::local().into(),
2615            tmpdir.clone(),
2616            Arc::new(LanceCache::no_cache()),
2617        ));
2618
2619        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
2620
2621        // index of docs:
2622        // 0: lance
2623        // 1: lake lake
2624        // 2: lake lake lake
2625        builder.tokens.add("lance".to_owned());
2626        builder.tokens.add("lake".to_owned());
2627        builder.posting_lists.push(PostingListBuilder::new(false));
2628        builder.posting_lists.push(PostingListBuilder::new(false));
2629        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
2630        builder.posting_lists[1].add(1, PositionRecorder::Count(2));
2631        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
2632        builder.docs.append(0, 1);
2633        builder.docs.append(1, 1);
2634        builder.docs.append(2, 1);
2635        builder.write(store.as_ref()).await.unwrap();
2636
2637        let index = InvertedPartition::load(
2638            store.clone(),
2639            0,
2640            None,
2641            &LanceCache::no_cache(),
2642            TokenSetFormat::default(),
2643        )
2644        .await
2645        .unwrap();
2646        let mut builder = index.into_builder().await.unwrap();
2647
2648        let mapping = HashMap::from([(0, None), (2, Some(3))]);
2649        builder.remap(&mapping).await.unwrap();
2650
2651        // after remap, the doc 0 is removed, and the doc 2 is updated to 3
2652        assert_eq!(builder.tokens.len(), 1);
2653        assert_eq!(builder.tokens.get("lake"), Some(0));
2654        assert_eq!(builder.posting_lists.len(), 1);
2655        assert_eq!(builder.posting_lists[0].len(), 2);
2656        assert_eq!(builder.docs.len(), 2);
2657        assert_eq!(builder.docs.row_id(0), 1);
2658        assert_eq!(builder.docs.row_id(1), 3);
2659
2660        builder.write(store.as_ref()).await.unwrap();
2661
2662        // remap to delete all docs
2663        let mapping = HashMap::from([(1, None), (3, None)]);
2664        builder.remap(&mapping).await.unwrap();
2665
2666        assert_eq!(builder.tokens.len(), 0);
2667        assert_eq!(builder.posting_lists.len(), 0);
2668        assert_eq!(builder.docs.len(), 0);
2669
2670        builder.write(store.as_ref()).await.unwrap();
2671    }
2672
2673    #[tokio::test]
2674    async fn test_posting_cache_conflict_across_partitions() {
2675        let tmpdir = TempObjDir::default();
2676        let store = Arc::new(LanceIndexStore::new(
2677            ObjectStore::local().into(),
2678            tmpdir.clone(),
2679            Arc::new(LanceCache::no_cache()),
2680        ));
2681
2682        // Create first partition with one token and posting list length 1
2683        let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
2684        builder1.tokens.add("test".to_owned());
2685        builder1.posting_lists.push(PostingListBuilder::new(false));
2686        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
2687        builder1.docs.append(100, 1); // row_id=100, num_tokens=1
2688        builder1.write(store.as_ref()).await.unwrap();
2689
2690        // Create second partition with one token and posting list length 4
2691        let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
2692        builder2.tokens.add("test".to_owned()); // Use same token to test cache prefix fix
2693        builder2.posting_lists.push(PostingListBuilder::new(false));
2694        builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
2695        builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
2696        builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
2697        builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
2698        builder2.docs.append(200, 2); // row_id=200, num_tokens=2
2699        builder2.docs.append(201, 1); // row_id=201, num_tokens=1
2700        builder2.docs.append(202, 3); // row_id=202, num_tokens=3
2701        builder2.docs.append(203, 1); // row_id=203, num_tokens=1
2702        builder2.write(store.as_ref()).await.unwrap();
2703
2704        // Create metadata file with both partitions
2705        let metadata = std::collections::HashMap::from_iter(vec![
2706            (
2707                "partitions".to_owned(),
2708                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
2709            ),
2710            (
2711                "params".to_owned(),
2712                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
2713            ),
2714            (
2715                TOKEN_SET_FORMAT_KEY.to_owned(),
2716                TokenSetFormat::default().to_string(),
2717            ),
2718        ]);
2719        let mut writer = store
2720            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
2721            .await
2722            .unwrap();
2723        writer.finish_with_metadata(metadata).await.unwrap();
2724
2725        // Load the inverted index
2726        let cache = Arc::new(LanceCache::with_capacity(4096));
2727        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
2728            .await
2729            .unwrap();
2730
2731        // Verify the index structure
2732        assert_eq!(index.partitions.len(), 2);
2733        assert_eq!(index.partitions[0].tokens.len(), 1);
2734        assert_eq!(index.partitions[1].tokens.len(), 1);
2735
2736        // Verify the partitions were loaded correctly
2737
2738        // Verify posting list lengths (note: partition order may differ from creation order)
2739        // Verify based on actual loading order
2740        if index.partitions[0].id() == 0 {
2741            // If partition[0] is ID=0, then it should have 1 document
2742            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
2743            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
2744            assert_eq!(index.partitions[0].docs.len(), 1);
2745            assert_eq!(index.partitions[1].docs.len(), 4);
2746        } else {
2747            // If partition[0] is ID=1, then it should have 4 documents
2748            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
2749            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
2750            assert_eq!(index.partitions[0].docs.len(), 4);
2751            assert_eq!(index.partitions[1].docs.len(), 1);
2752        }
2753
2754        // Prewarm the inverted index (this loads posting lists into cache)
2755        index.prewarm().await.unwrap();
2756
2757        let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
2758        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
2759        let prefilter = Arc::new(NoFilter);
2760        let metrics = Arc::new(NoOpMetricsCollector);
2761
2762        let (row_ids, scores) = index
2763            .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
2764            .await
2765            .unwrap();
2766
2767        // Verify that we got search results
2768        // Expected to find 5 documents: 1 from first partition, 4 from second partition
2769        assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
2770        assert!(!row_ids.is_empty(), "Should find at least some documents");
2771        assert_eq!(row_ids.len(), scores.len());
2772
2773        // All scores should be positive since all documents contain the search token
2774        for &score in &scores {
2775            assert!(score > 0.0, "All scores should be positive");
2776        }
2777
2778        // Check that we got results from both partitions
2779        assert!(
2780            row_ids.contains(&100),
2781            "Should contain row_id from partition 0"
2782        );
2783        assert!(
2784            row_ids.iter().any(|&id| id >= 200),
2785            "Should contain row_id from partition 1"
2786        );
2787    }
2788
2789    #[test]
2790    fn test_block_max_scores_capacity_matches_block_count() {
2791        let mut docs = DocSet::default();
2792        let num_docs = BLOCK_SIZE * 3 + 7;
2793        let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
2794        for doc_id in &doc_ids {
2795            docs.append(*doc_id as u64, 1);
2796        }
2797
2798        let freqs = vec![1_u32; doc_ids.len()];
2799        let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
2800        let expected_blocks = doc_ids.len().div_ceil(BLOCK_SIZE);
2801
2802        assert_eq!(block_max_scores.len(), expected_blocks);
2803        assert_eq!(block_max_scores.capacity(), expected_blocks);
2804    }
2805
2806    #[tokio::test]
2807    async fn test_bm25_search_uses_global_idf() {
2808        let tmpdir = TempObjDir::default();
2809        let store = Arc::new(LanceIndexStore::new(
2810            ObjectStore::local().into(),
2811            tmpdir.clone(),
2812            Arc::new(LanceCache::no_cache()),
2813        ));
2814
2815        // Partition 0: 3 docs, only one contains "alpha".
2816        let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default());
2817        builder0.tokens.add("alpha".to_owned());
2818        builder0.tokens.add("beta".to_owned());
2819        builder0.posting_lists.push(PostingListBuilder::new(false));
2820        builder0.posting_lists.push(PostingListBuilder::new(false));
2821        builder0.posting_lists[0].add(0, PositionRecorder::Count(1));
2822        builder0.posting_lists[1].add(1, PositionRecorder::Count(1));
2823        builder0.posting_lists[1].add(2, PositionRecorder::Count(1));
2824        builder0.docs.append(100, 1);
2825        builder0.docs.append(101, 1);
2826        builder0.docs.append(102, 1);
2827        builder0.write(store.as_ref()).await.unwrap();
2828
2829        // Partition 1: 1 doc, contains "alpha".
2830        let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default());
2831        builder1.tokens.add("alpha".to_owned());
2832        builder1.posting_lists.push(PostingListBuilder::new(false));
2833        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
2834        builder1.docs.append(200, 1);
2835        builder1.write(store.as_ref()).await.unwrap();
2836
2837        let metadata = std::collections::HashMap::from_iter(vec![
2838            (
2839                "partitions".to_owned(),
2840                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
2841            ),
2842            (
2843                "params".to_owned(),
2844                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
2845            ),
2846            (
2847                TOKEN_SET_FORMAT_KEY.to_owned(),
2848                TokenSetFormat::default().to_string(),
2849            ),
2850        ]);
2851        let mut writer = store
2852            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
2853            .await
2854            .unwrap();
2855        writer.finish_with_metadata(metadata).await.unwrap();
2856
2857        let cache = Arc::new(LanceCache::with_capacity(4096));
2858        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
2859            .await
2860            .unwrap();
2861
2862        let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text));
2863        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
2864        let prefilter = Arc::new(NoFilter);
2865        let metrics = Arc::new(NoOpMetricsCollector);
2866
2867        let (row_ids, scores) = index
2868            .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
2869            .await
2870            .unwrap();
2871
2872        assert_eq!(row_ids.len(), 2);
2873        assert!(row_ids.contains(&100));
2874        assert!(row_ids.contains(&200));
2875        assert_eq!(row_ids.len(), scores.len());
2876
2877        let expected_idf = idf(2, 4);
2878        for score in scores {
2879            assert!(
2880                (score - expected_idf).abs() < 1e-6,
2881                "score: {}, expected: {}",
2882                score,
2883                expected_idf
2884            );
2885        }
2886    }
2887}