Skip to main content

lance_index/scalar/inverted/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::{
7    cmp::{min, Reverse},
8    collections::BinaryHeap,
9};
10use std::{collections::HashMap, ops::Range};
11
12use crate::metrics::NoOpMetricsCollector;
13use crate::prefilter::NoFilter;
14use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
15use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
16use arrow::{
17    array::{
18        AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
19    },
20    buffer::OffsetBuffer,
21};
22use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
23use arrow_array::{
24    Array, ArrayRef, BooleanArray, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait,
25    RecordBatch, UInt32Array, UInt64Array,
26};
27use arrow_schema::{DataType, Field, Schema, SchemaRef};
28use async_trait::async_trait;
29use datafusion::execution::SendableRecordBatchStream;
30use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31use datafusion_common::DataFusionError;
32use deepsize::DeepSizeOf;
33use fst::{Automaton, IntoStreamer, Streamer};
34use futures::{stream, FutureExt, StreamExt, TryStreamExt};
35use itertools::Itertools;
36use lance_arrow::{iter_str_array, RecordBatchExt};
37use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
38use lance_core::utils::mask::RowIdTreeMap;
39use lance_core::utils::{
40    mask::RowIdMask,
41    tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
42};
43use lance_core::{
44    container::list::ExpLinkedList,
45    utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
46};
47use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
48use roaring::RoaringBitmap;
49use snafu::location;
50use std::sync::LazyLock;
51use tokio::task::spawn_blocking;
52use tracing::{info, instrument};
53
54use super::{
55    builder::{
56        doc_file_path, inverted_list_schema, posting_file_path, token_file_path, ScoredDoc,
57        BLOCK_SIZE,
58    },
59    iter::PlainPostingListIterator,
60    query::*,
61    scorer::{idf, IndexBM25Scorer, Scorer, B, K1},
62};
63use super::{
64    builder::{InnerBuilder, PositionRecorder},
65    encoding::compress_posting_list,
66    iter::CompressedPostingListIterator,
67};
68use super::{
69    encoding::compress_positions,
70    iter::{PostingListIterator, TokenIterator, TokenSource},
71};
72use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
73use crate::frag_reuse::FragReuseIndex;
74use crate::pbold;
75use crate::scalar::inverted::lance_tokenizer::TextTokenizer;
76use crate::scalar::inverted::scorer::MemBM25Scorer;
77use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
78use crate::scalar::{
79    AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
80    ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
81};
82use crate::Index;
83use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
84use std::str::FromStr;
85
// Version 0: Arrow TokenSetFormat (legacy)
// Version 1: Fst TokenSetFormat (new default, incompatible clients < 0.38)
/// Index version reported for indices whose token set uses the FST format.
pub const INVERTED_INDEX_VERSION: u32 = 1;
/// Token set file name; opened by the legacy (single-partition) loader in this module.
pub const TOKENS_FILE: &str = "tokens.lance";
/// Posting-list file name; opened by the legacy (single-partition) loader in this module.
pub const INVERT_LIST_FILE: &str = "invert.lance";
/// Doc set file name; opened by the legacy (single-partition) loader in this module.
pub const DOCS_FILE: &str = "docs.lance";
/// Metadata file of the partitioned layout. Its schema metadata carries the `params`,
/// `partitions` and token-set-format keys read by `InvertedIndex::load`.
pub const METADATA_FILE: &str = "metadata.lance";

// Column names used by the inverted index (index-file columns and result columns).
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
pub const SCORE_COL: &str = "_score";
/// Schema metadata key recording the token set format of a partitioned index.
/// When the key is absent, `InvertedIndex::load` assumes the Arrow format.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";

/// Nullable `Float32` field holding the score column of search results.
pub static SCORE_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
/// Full-text search output schema: the row id followed by the score.
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
/// Schema with only the row id column; used for the batches built by `do_search`.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
116
/// On-disk encoding of an inverted index's token dictionary.
///
/// `Arrow` is the legacy encoding (index version 0). `Fst` is the current
/// default (index version 1).
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum TokenSetFormat {
    /// Legacy Arrow-based token set.
    Arrow,
    /// FST-based token set; the default for new indices.
    #[default]
    Fst,
}

impl Display for TokenSetFormat {
    /// Writes the lowercase name (`"arrow"` / `"fst"`), which `FromStr` parses back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TokenSetFormat::Arrow => "arrow",
            TokenSetFormat::Fst => "fst",
        };
        f.write_str(name)
    }
}
132
133impl FromStr for TokenSetFormat {
134    type Err = Error;
135
136    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
137        match s.trim() {
138            "" => Ok(Self::Arrow),
139            "arrow" => Ok(Self::Arrow),
140            "fst" => Ok(Self::Fst),
141            other => Err(Error::Index {
142                message: format!("unsupported token set format {}", other),
143                location: location!(),
144            }),
145        }
146    }
147}
148
149impl DeepSizeOf for TokenSetFormat {
150    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
151        0
152    }
153}
154
/// Full-text (inverted) index made up of one or more partitions.
///
/// Indices without a metadata file (the legacy layout) are loaded as a single
/// partition with id 0.
#[derive(Clone)]
pub struct InvertedIndex {
    /// Parameters the index was built with; `tokenizer` is built from them.
    params: InvertedIndexParams,
    /// Store the index files were loaded from.
    store: Arc<dyn IndexStore>,
    /// Tokenizer built from `params`; cloned whenever a caller needs one.
    tokenizer: Box<dyn LanceTokenizer>,
    /// Encoding used by the partitions' token sets.
    token_set_format: TokenSetFormat,
    /// Loaded partitions of the index.
    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
}
163
164impl Debug for InvertedIndex {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        f.debug_struct("InvertedIndex")
167            .field("params", &self.params)
168            .field("token_set_format", &self.token_set_format)
169            .field("partitions", &self.partitions)
170            .finish()
171    }
172}
173
174impl DeepSizeOf for InvertedIndex {
175    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
176        self.partitions.deep_size_of_children(context)
177    }
178}
179
180impl InvertedIndex {
181    fn to_builder(&self) -> InvertedIndexBuilder {
182        self.to_builder_with_offset(None)
183    }
184
185    fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
186        if self.is_legacy() {
187            // for legacy format, we re-create the index in the new format
188            InvertedIndexBuilder::from_existing_index(
189                self.params.clone(),
190                None,
191                Vec::new(),
192                self.token_set_format,
193                fragment_mask,
194            )
195        } else {
196            let partitions = match fragment_mask {
197                Some(fragment_mask) => self
198                    .partitions
199                    .iter()
200                    // Filter partitions that belong to the specified fragment
201                    // The mask contains fragment_id in high 32 bits, we check if partition's
202                    // fragment_id matches by comparing the masked result with the original mask
203                    .filter(|part| part.belongs_to_fragment(fragment_mask))
204                    .map(|part| part.id())
205                    .collect(),
206                None => self.partitions.iter().map(|part| part.id()).collect(),
207            };
208
209            InvertedIndexBuilder::from_existing_index(
210                self.params.clone(),
211                Some(self.store.clone()),
212                partitions,
213                self.token_set_format,
214                fragment_mask,
215            )
216        }
217    }
218
219    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
220        self.tokenizer.clone()
221    }
222
223    pub fn params(&self) -> &InvertedIndexParams {
224        &self.params
225    }
226
227    // search the documents that contain the query
228    // return the row ids of the documents sorted by bm25 score
229    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
230    // we first calculate in-partition BM25 scores,
231    // then re-calculate the scores for the top k documents across all partitions
232    #[instrument(level = "debug", skip_all)]
233    pub async fn bm25_search(
234        &self,
235        tokens: Arc<Tokens>,
236        params: Arc<FtsSearchParams>,
237        operator: Operator,
238        prefilter: Arc<dyn PreFilter>,
239        metrics: Arc<dyn MetricsCollector>,
240    ) -> Result<(Vec<u64>, Vec<f32>)> {
241        let limit = params.limit.unwrap_or(usize::MAX);
242        if limit == 0 {
243            return Ok((Vec::new(), Vec::new()));
244        }
245        let mask = prefilter.mask();
246
247        let mut candidates = BinaryHeap::new();
248        let parts = self
249            .partitions
250            .iter()
251            .map(|part| {
252                let part = part.clone();
253                let tokens = tokens.clone();
254                let params = params.clone();
255                let mask = mask.clone();
256                let metrics = metrics.clone();
257                async move {
258                    let postings = part
259                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
260                        .await?;
261                    if postings.is_empty() {
262                        return Ok(Vec::new());
263                    }
264                    let params = params.clone();
265                    let mask = mask.clone();
266                    let metrics = metrics.clone();
267                    spawn_cpu(move || {
268                        part.bm25_search(
269                            params.as_ref(),
270                            operator,
271                            mask,
272                            postings,
273                            metrics.as_ref(),
274                        )
275                    })
276                    .await
277                }
278            })
279            .collect::<Vec<_>>();
280        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
281        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
282        while let Some(res) = parts.try_next().await? {
283            for DocCandidate {
284                row_id,
285                freqs,
286                doc_length,
287            } in res
288            {
289                let mut score = 0.0;
290                for (token, freq) in freqs.into_iter() {
291                    score += scorer.score(token.as_str(), freq, doc_length);
292                }
293                if candidates.len() < limit {
294                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
295                } else if candidates.peek().unwrap().0.score.0 < score {
296                    candidates.pop();
297                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
298                }
299            }
300        }
301
302        Ok(candidates
303            .into_sorted_vec()
304            .into_iter()
305            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
306            .unzip())
307    }
308
309    async fn load_legacy_index(
310        store: Arc<dyn IndexStore>,
311        frag_reuse_index: Option<Arc<FragReuseIndex>>,
312        index_cache: &LanceCache,
313    ) -> Result<Arc<Self>> {
314        log::warn!("loading legacy FTS index");
315        let tokens_fut = tokio::spawn({
316            let store = store.clone();
317            async move {
318                let token_reader = store.open_index_file(TOKENS_FILE).await?;
319                let tokenizer = token_reader
320                    .schema()
321                    .metadata
322                    .get("tokenizer")
323                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
324                    .transpose()?
325                    .unwrap_or_default();
326                let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
327                Result::Ok((tokenizer, tokens))
328            }
329        });
330        let invert_list_fut = tokio::spawn({
331            let store = store.clone();
332            let index_cache_clone = index_cache.clone();
333            async move {
334                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
335                let invert_list =
336                    PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
337                Result::Ok(Arc::new(invert_list))
338            }
339        });
340        let docs_fut = tokio::spawn({
341            let store = store.clone();
342            async move {
343                let docs_reader = store.open_index_file(DOCS_FILE).await?;
344                let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
345                Result::Ok(docs)
346            }
347        });
348
349        let (tokenizer_config, tokens) = tokens_fut.await??;
350        let inverted_list = invert_list_fut.await??;
351        let docs = docs_fut.await??;
352
353        let tokenizer = tokenizer_config.build()?;
354
355        Ok(Arc::new(Self {
356            params: tokenizer_config,
357            store: store.clone(),
358            tokenizer,
359            token_set_format: TokenSetFormat::Arrow,
360            partitions: vec![Arc::new(InvertedPartition {
361                id: 0,
362                store,
363                tokens,
364                inverted_list,
365                docs,
366                token_set_format: TokenSetFormat::Arrow,
367            })],
368        }))
369    }
370
371    pub fn is_legacy(&self) -> bool {
372        self.partitions.len() == 1 && self.partitions[0].is_legacy()
373    }
374
375    pub async fn load(
376        store: Arc<dyn IndexStore>,
377        frag_reuse_index: Option<Arc<FragReuseIndex>>,
378        index_cache: &LanceCache,
379    ) -> Result<Arc<Self>>
380    where
381        Self: Sized,
382    {
383        // for new index format, there is a metadata file and multiple partitions,
384        // each partition is a separate index containing tokens, inverted list and docs.
385        // for old index format, there is no metadata file, and it's just like a single partition
386
387        match store.open_index_file(METADATA_FILE).await {
388            Ok(reader) => {
389                let params = reader.schema().metadata.get("params").ok_or(Error::Index {
390                    message: "params not found in metadata".to_owned(),
391                    location: location!(),
392                })?;
393                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
394                let partitions =
395                    reader
396                        .schema()
397                        .metadata
398                        .get("partitions")
399                        .ok_or(Error::Index {
400                            message: "partitions not found in metadata".to_owned(),
401                            location: location!(),
402                        })?;
403                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
404                let token_set_format = reader
405                    .schema()
406                    .metadata
407                    .get(TOKEN_SET_FORMAT_KEY)
408                    .map(|name| TokenSetFormat::from_str(name))
409                    .transpose()?
410                    .unwrap_or(TokenSetFormat::Arrow);
411
412                let format = token_set_format;
413                let partitions = partitions.into_iter().map(|id| {
414                    let store = store.clone();
415                    let frag_reuse_index_clone = frag_reuse_index.clone();
416                    let index_cache_for_part =
417                        index_cache.with_key_prefix(format!("part-{}", id).as_str());
418                    let token_set_format = format;
419                    async move {
420                        Result::Ok(Arc::new(
421                            InvertedPartition::load(
422                                store,
423                                id,
424                                frag_reuse_index_clone,
425                                &index_cache_for_part,
426                                token_set_format,
427                            )
428                            .await?,
429                        ))
430                    }
431                });
432                let partitions = stream::iter(partitions)
433                    .buffer_unordered(store.io_parallelism())
434                    .try_collect::<Vec<_>>()
435                    .await?;
436
437                let tokenizer = params.build()?;
438                Ok(Arc::new(Self {
439                    params,
440                    store,
441                    tokenizer,
442                    token_set_format,
443                    partitions,
444                }))
445            }
446            Err(_) => {
447                // old index format
448                Self::load_legacy_index(store, frag_reuse_index, index_cache).await
449            }
450        }
451    }
452}
453
454#[async_trait]
455impl Index for InvertedIndex {
456    fn as_any(&self) -> &dyn std::any::Any {
457        self
458    }
459
460    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
461        self
462    }
463
464    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
465        Err(Error::invalid_input(
466            "inverted index cannot be cast to vector index",
467            location!(),
468        ))
469    }
470
471    fn statistics(&self) -> Result<serde_json::Value> {
472        let num_tokens = self
473            .partitions
474            .iter()
475            .map(|part| part.tokens.len())
476            .sum::<usize>();
477        let num_docs = self
478            .partitions
479            .iter()
480            .map(|part| part.docs.len())
481            .sum::<usize>();
482        Ok(serde_json::json!({
483            "params": self.params,
484            "num_tokens": num_tokens,
485            "num_docs": num_docs,
486        }))
487    }
488
489    async fn prewarm(&self) -> Result<()> {
490        for part in &self.partitions {
491            part.inverted_list.prewarm().await?;
492        }
493        Ok(())
494    }
495
496    fn index_type(&self) -> crate::IndexType {
497        crate::IndexType::Inverted
498    }
499
500    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
501        unimplemented!()
502    }
503}
504
505impl InvertedIndex {
506    /// Search docs match the input text.
507    async fn do_search(&self, text: &str) -> Result<RecordBatch> {
508        let params = FtsSearchParams::new();
509        let mut tokenizer = self.tokenizer.clone();
510        let tokens = collect_query_tokens(text, &mut tokenizer, None);
511
512        let (doc_ids, _) = self
513            .bm25_search(
514                Arc::new(tokens),
515                params.into(),
516                Operator::And,
517                Arc::new(NoFilter),
518                Arc::new(NoOpMetricsCollector),
519            )
520            .boxed()
521            .await?;
522
523        Ok(RecordBatch::try_new(
524            ROW_ID_SCHEMA.clone(),
525            vec![Arc::new(UInt64Array::from(doc_ids))],
526        )?)
527    }
528}
529
530#[async_trait]
531impl ScalarIndex for InvertedIndex {
532    // return the row ids of the documents that contain the query
533    #[instrument(level = "debug", skip_all)]
534    async fn search(
535        &self,
536        query: &dyn AnyQuery,
537        _metrics: &dyn MetricsCollector,
538    ) -> Result<SearchResult> {
539        let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();
540
541        match query {
542            TokenQuery::TokensContains(text) => {
543                let records = self.do_search(text).await?;
544                let row_ids = records
545                    .column(0)
546                    .as_any()
547                    .downcast_ref::<UInt64Array>()
548                    .unwrap();
549                let row_ids = row_ids.iter().flatten().collect_vec();
550                Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids)))
551            }
552        }
553    }
554
555    fn can_remap(&self) -> bool {
556        true
557    }
558
559    async fn remap(
560        &self,
561        mapping: &HashMap<u64, Option<u64>>,
562        dest_store: &dyn IndexStore,
563    ) -> Result<CreatedIndex> {
564        self.to_builder()
565            .remap(mapping, self.store.clone(), dest_store)
566            .await?;
567
568        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
569
570        // Use version 0 for Arrow format (legacy), version 1 for Fst format (new)
571        let index_version = match self.token_set_format {
572            TokenSetFormat::Arrow => 0,
573            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
574        };
575
576        Ok(CreatedIndex {
577            index_details: prost_types::Any::from_msg(&details).unwrap(),
578            index_version,
579        })
580    }
581
582    async fn update(
583        &self,
584        new_data: SendableRecordBatchStream,
585        dest_store: &dyn IndexStore,
586    ) -> Result<CreatedIndex> {
587        self.to_builder().update(new_data, dest_store).await?;
588
589        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
590
591        // Use version 0 for Arrow format (legacy), version 1 for Fst format (new)
592        let index_version = match self.token_set_format {
593            TokenSetFormat::Arrow => 0,
594            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
595        };
596
597        Ok(CreatedIndex {
598            index_details: prost_types::Any::from_msg(&details).unwrap(),
599            index_version,
600        })
601    }
602
603    fn update_criteria(&self) -> UpdateCriteria {
604        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
605        if self.is_legacy() {
606            UpdateCriteria::requires_old_data(criteria)
607        } else {
608            UpdateCriteria::only_new_data(criteria)
609        }
610    }
611
612    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
613        let mut params = self.params.clone();
614        if params.base_tokenizer.is_empty() {
615            params.base_tokenizer = "simple".to_string();
616        }
617
618        let params_json = serde_json::to_string(&params)?;
619
620        Ok(ScalarIndexParams {
621            index_type: BuiltinIndexType::Inverted.as_str().to_string(),
622            params: Some(params_json),
623        })
624    }
625}
626
/// One partition of an inverted index: its own token set, posting lists and docs.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct InvertedPartition {
    // 0 for legacy format
    // NOTE(review): for distributed builds the high 32 bits appear to carry the
    // fragment id (see `belongs_to_fragment`) — confirm against the builder.
    id: u64,
    /// Store the partition files were opened from.
    store: Arc<dyn IndexStore>,
    /// Token → token id dictionary of this partition.
    pub(crate) tokens: TokenSet,
    /// Reader for this partition's posting lists.
    pub(crate) inverted_list: Arc<PostingListReader>,
    /// Documents of this partition (used for doc count and BM25 scoring).
    pub(crate) docs: DocSet,
    /// Encoding of `tokens`.
    token_set_format: TokenSetFormat,
}
637
638impl InvertedPartition {
639    /// Check if this partition belongs to the specified fragment.
640    ///
641    /// This method encapsulates the bit manipulation logic for fragment filtering
642    /// in distributed indexing scenarios.
643    ///
644    /// # Arguments
645    /// * `fragment_mask` - A mask with fragment_id in high 32 bits
646    ///
647    /// # Returns
648    /// * `true` if the partition belongs to the fragment, `false` otherwise
649    pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
650        (self.id() & fragment_mask) == fragment_mask
651    }
652
653    pub fn id(&self) -> u64 {
654        self.id
655    }
656
657    pub fn store(&self) -> &dyn IndexStore {
658        self.store.as_ref()
659    }
660
661    pub fn is_legacy(&self) -> bool {
662        self.inverted_list.lengths.is_none()
663    }
664
665    pub async fn load(
666        store: Arc<dyn IndexStore>,
667        id: u64,
668        frag_reuse_index: Option<Arc<FragReuseIndex>>,
669        index_cache: &LanceCache,
670        token_set_format: TokenSetFormat,
671    ) -> Result<Self> {
672        let token_file = store.open_index_file(&token_file_path(id)).await?;
673        let tokens = TokenSet::load(token_file, token_set_format).await?;
674        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
675        let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
676        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
677        let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;
678
679        Ok(Self {
680            id,
681            store,
682            tokens,
683            inverted_list: Arc::new(inverted_list),
684            docs,
685            token_set_format,
686        })
687    }
688
689    fn map(&self, token: &str) -> Option<u32> {
690        self.tokens.get(token)
691    }
692
693    pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
694        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
695        for token in tokens {
696            let fuzziness = match params.fuzziness {
697                Some(fuzziness) => fuzziness,
698                None => MatchQuery::auto_fuzziness(token),
699            };
700            let lev =
701                fst::automaton::Levenshtein::new(token, fuzziness).map_err(|e| Error::Index {
702                    message: format!("failed to construct the fuzzy query: {}", e),
703                    location: location!(),
704                })?;
705
706            let base_len = tokens.token_type().prefix_len(token) as u32;
707            if let TokenMap::Fst(ref map) = self.tokens.tokens {
708                match base_len + params.prefix_length {
709                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
710                    prefix_length => {
711                        let prefix = &token[..min(prefix_length as usize, token.len())];
712                        let prefix = fst::automaton::Str::new(prefix).starts_with();
713                        take_fst_keys(
714                            map.search(lev.intersection(prefix)),
715                            &mut new_tokens,
716                            params.max_expansions,
717                        )
718                    }
719                }
720            } else {
721                return Err(Error::Index {
722                    message: "tokens is not fst, which is not expected".to_owned(),
723                    location: location!(),
724                });
725            }
726        }
727        Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
728    }
729
730    // search the documents that contain the query
731    // return the doc info and the doc length
732    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
733    #[instrument(level = "debug", skip_all)]
734    pub async fn load_posting_lists(
735        &self,
736        tokens: &Tokens,
737        params: &FtsSearchParams,
738        metrics: &dyn MetricsCollector,
739    ) -> Result<Vec<PostingIterator>> {
740        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
741        let is_phrase_query = params.phrase_slop.is_some();
742        let tokens = match is_fuzzy {
743            true => self.expand_fuzzy(tokens, params)?,
744            false => tokens.clone(),
745        };
746        let mut token_ids = Vec::with_capacity(tokens.len());
747        for token in tokens {
748            let token_id = self.map(&token);
749            if let Some(token_id) = token_id {
750                token_ids.push((token_id, token));
751            } else if is_phrase_query {
752                // if the token is not found, we can't do phrase query
753                return Ok(Vec::new());
754            }
755        }
756        if token_ids.is_empty() {
757            return Ok(Vec::new());
758        }
759        if !is_phrase_query {
760            // remove duplicates
761            token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
762            token_ids.dedup_by_key(|(token_id, _)| *token_id);
763        }
764
765        let num_docs = self.docs.len();
766        stream::iter(token_ids)
767            .enumerate()
768            .map(|(position, (token_id, token))| async move {
769                let posting = self
770                    .inverted_list
771                    .posting_list(token_id, is_phrase_query, metrics)
772                    .await?;
773
774                Result::Ok(PostingIterator::new(
775                    token,
776                    token_id,
777                    position as u32,
778                    posting,
779                    num_docs,
780                ))
781            })
782            .buffered(self.store.io_parallelism())
783            .try_collect::<Vec<_>>()
784            .await
785    }
786
787    #[instrument(level = "debug", skip_all)]
788    pub fn bm25_search(
789        &self,
790        params: &FtsSearchParams,
791        operator: Operator,
792        mask: Arc<RowIdMask>,
793        postings: Vec<PostingIterator>,
794        metrics: &dyn MetricsCollector,
795    ) -> Result<Vec<DocCandidate>> {
796        if postings.is_empty() {
797            return Ok(Vec::new());
798        }
799
800        // let local_metrics = LocalMetricsCollector::default();
801        let scorer = IndexBM25Scorer::new(std::iter::once(self));
802        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
803        let hits = wand.search(params, mask, metrics)?;
804        // local_metrics.dump_into(metrics);
805        Ok(hits)
806    }
807
808    pub async fn into_builder(self) -> Result<InnerBuilder> {
809        let mut builder = InnerBuilder::new(
810            self.id,
811            self.inverted_list.has_positions(),
812            self.token_set_format,
813        );
814        builder.tokens = self.tokens;
815        builder.docs = self.docs;
816
817        builder
818            .posting_lists
819            .reserve_exact(self.inverted_list.len());
820        for posting_list in self
821            .inverted_list
822            .read_all(self.inverted_list.has_positions())
823            .await?
824        {
825            let posting_list = posting_list?;
826            builder
827                .posting_lists
828                .push(posting_list.into_builder(&builder.docs));
829        }
830        Ok(builder)
831    }
832}
833
// at indexing, we use HashMap because we need it to be mutable,
// at searching, we use fst::Map because it's more efficient
/// Token → token id dictionary, either mutable (`HashMap`) or FST-backed (`Fst`).
#[derive(Debug, Clone)]
pub enum TokenMap {
    /// Mutable map used while building the index.
    HashMap(HashMap<String, u32>),
    /// Immutable FST map; token ids are stored as the FST's `u64` values.
    Fst(fst::Map<Vec<u8>>),
}
841
842impl Default for TokenMap {
843    fn default() -> Self {
844        Self::HashMap(HashMap::new())
845    }
846}
847
848impl DeepSizeOf for TokenMap {
849    fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
850        match self {
851            Self::HashMap(map) => map.deep_size_of_children(ctx),
852            Self::Fst(map) => map.as_fst().size(),
853        }
854    }
855}
856
857impl TokenMap {
858    pub fn len(&self) -> usize {
859        match self {
860            Self::HashMap(map) => map.len(),
861            Self::Fst(map) => map.len(),
862        }
863    }
864
865    pub fn is_empty(&self) -> bool {
866        self.len() == 0
867    }
868}
869
/// A mapping from token strings to token ids, plus the bookkeeping that is
/// persisted alongside it.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct TokenSet {
    /// token -> token_id
    pub(crate) tokens: TokenMap,
    /// The id that `add` assigns to the next previously-unseen token.
    pub(crate) next_id: u32,
    /// Sum of the byte lengths of the tokens added through `add` or loaded from
    /// storage. Used to pre-size the string builder for the Arrow format and
    /// persisted as-is in the FST format.
    total_length: usize,
}
878
879impl TokenSet {
880    pub fn into_mut(self) -> Self {
881        let tokens = match self.tokens {
882            TokenMap::HashMap(map) => map,
883            TokenMap::Fst(map) => {
884                let mut new_map = HashMap::with_capacity(map.len());
885                let mut stream = map.into_stream();
886                while let Some((token, token_id)) = stream.next() {
887                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
888                }
889
890                new_map
891            }
892        };
893
894        Self {
895            tokens: TokenMap::HashMap(tokens),
896            next_id: self.next_id,
897            total_length: self.total_length,
898        }
899    }
900
901    pub fn len(&self) -> usize {
902        self.tokens.len()
903    }
904
905    pub fn is_empty(&self) -> bool {
906        self.len() == 0
907    }
908
909    pub(crate) fn iter(&self) -> TokenIterator<'_> {
910        TokenIterator::new(match &self.tokens {
911            TokenMap::HashMap(map) => TokenSource::HashMap(map.iter()),
912            TokenMap::Fst(map) => TokenSource::Fst(map.stream()),
913        })
914    }
915
916    pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
917        match format {
918            TokenSetFormat::Arrow => self.into_arrow_batch(),
919            TokenSetFormat::Fst => self.into_fst_batch(),
920        }
921    }
922
923    fn into_arrow_batch(self) -> Result<RecordBatch> {
924        let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
925        let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
926
927        match self.tokens {
928            TokenMap::Fst(map) => {
929                let mut stream = map.stream();
930                while let Some((token, token_id)) = stream.next() {
931                    token_builder.append_value(String::from_utf8_lossy(token));
932                    token_id_builder.append_value(token_id as u32);
933                }
934            }
935            TokenMap::HashMap(map) => {
936                for (token, token_id) in map.into_iter().sorted_unstable() {
937                    token_builder.append_value(token);
938                    token_id_builder.append_value(token_id);
939                }
940            }
941        }
942
943        let token_col = token_builder.finish();
944        let token_id_col = token_id_builder.finish();
945
946        let schema = arrow_schema::Schema::new(vec![
947            arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
948            arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
949        ]);
950
951        let batch = RecordBatch::try_new(
952            Arc::new(schema),
953            vec![
954                Arc::new(token_col) as ArrayRef,
955                Arc::new(token_id_col) as ArrayRef,
956            ],
957        )?;
958        Ok(batch)
959    }
960
961    fn into_fst_batch(mut self) -> Result<RecordBatch> {
962        let fst_map = match std::mem::take(&mut self.tokens) {
963            TokenMap::Fst(map) => map,
964            TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
965        };
966        let bytes = fst_map.into_fst().into_inner();
967
968        let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
969        fst_builder.append_value(bytes);
970        let fst_col = fst_builder.finish();
971
972        let mut next_id_builder = UInt32Builder::with_capacity(1);
973        next_id_builder.append_value(self.next_id);
974        let next_id_col = next_id_builder.finish();
975
976        let mut total_length_builder = UInt64Builder::with_capacity(1);
977        total_length_builder.append_value(self.total_length as u64);
978        let total_length_col = total_length_builder.finish();
979
980        let schema = arrow_schema::Schema::new(vec![
981            arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
982            arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
983            arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
984        ]);
985
986        let batch = RecordBatch::try_new(
987            Arc::new(schema),
988            vec![
989                Arc::new(fst_col) as ArrayRef,
990                Arc::new(next_id_col) as ArrayRef,
991                Arc::new(total_length_col) as ArrayRef,
992            ],
993        )?;
994        Ok(batch)
995    }
996
997    fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
998        let mut entries: Vec<_> = map.into_iter().collect();
999        entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1000        let mut builder = fst::MapBuilder::memory();
1001        for (token, token_id) in entries {
1002            builder
1003                .insert(&token, token_id as u64)
1004                .map_err(|e| Error::Index {
1005                    message: format!("failed to insert token {}: {}", token, e),
1006                    location: location!(),
1007                })?;
1008        }
1009        Ok(builder.into_map())
1010    }
1011
1012    pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
1013        match format {
1014            TokenSetFormat::Arrow => Self::load_arrow(reader).await,
1015            TokenSetFormat::Fst => Self::load_fst(reader).await,
1016        }
1017    }
1018
1019    async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1020        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1021
1022        let (tokens, next_id, total_length) = spawn_blocking(move || {
1023            let mut next_id = 0;
1024            let mut total_length = 0;
1025            let mut tokens = fst::MapBuilder::memory();
1026
1027            let token_col = batch[TOKEN_COL].as_string::<i32>();
1028            let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1029
1030            for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1031                let token = token.ok_or(Error::Index {
1032                    message: "found null token in token set".to_owned(),
1033                    location: location!(),
1034                })?;
1035                next_id = next_id.max(token_id + 1);
1036                total_length += token.len();
1037                tokens
1038                    .insert(token, token_id as u64)
1039                    .map_err(|e| Error::Index {
1040                        message: format!("failed to insert token {}: {}", token, e),
1041                        location: location!(),
1042                    })?;
1043            }
1044
1045            Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1046        })
1047        .await
1048        .map_err(|err| Error::Execution {
1049            message: format!("failed to spawn blocking task: {}", err),
1050            location: location!(),
1051        })??;
1052
1053        Ok(Self {
1054            tokens: TokenMap::Fst(tokens),
1055            next_id,
1056            total_length,
1057        })
1058    }
1059
1060    async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1061        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1062        if batch.num_rows() == 0 {
1063            return Err(Error::Index {
1064                message: "token set batch is empty".to_owned(),
1065                location: location!(),
1066            });
1067        }
1068
1069        let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1070        let bytes = fst_col.value(0);
1071        let map = fst::Map::new(bytes.to_vec()).map_err(|e| Error::Index {
1072            message: format!("failed to load fst tokens: {}", e),
1073            location: location!(),
1074        })?;
1075
1076        let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1077        let total_length_col =
1078            batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1079
1080        let next_id = next_id_col.values().first().copied().ok_or(Error::Index {
1081            message: "token next id column is empty".to_owned(),
1082            location: location!(),
1083        })?;
1084
1085        let total_length = total_length_col
1086            .values()
1087            .first()
1088            .copied()
1089            .ok_or(Error::Index {
1090                message: "token total length column is empty".to_owned(),
1091                location: location!(),
1092            })?;
1093
1094        Ok(Self {
1095            tokens: TokenMap::Fst(map),
1096            next_id,
1097            total_length: usize::try_from(total_length).map_err(|_| Error::Index {
1098                message: format!("token total length {} overflows usize", total_length),
1099                location: location!(),
1100            })?,
1101        })
1102    }
1103
1104    pub fn add(&mut self, token: String) -> u32 {
1105        let next_id = self.next_id();
1106        let len = token.len();
1107        let token_id = match self.tokens {
1108            TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1109            _ => unreachable!("tokens must be HashMap while indexing"),
1110        };
1111
1112        // add token if it doesn't exist
1113        if token_id == next_id {
1114            self.next_id += 1;
1115            self.total_length += len;
1116        }
1117
1118        token_id
1119    }
1120
1121    pub fn get(&self, token: &str) -> Option<u32> {
1122        match self.tokens {
1123            TokenMap::HashMap(ref map) => map.get(token).copied(),
1124            TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
1125        }
1126    }
1127
1128    // the `removed_token_ids` must be sorted
1129    pub fn remap(&mut self, removed_token_ids: &[u32]) {
1130        if removed_token_ids.is_empty() {
1131            return;
1132        }
1133
1134        let mut map = match std::mem::take(&mut self.tokens) {
1135            TokenMap::HashMap(map) => map,
1136            TokenMap::Fst(map) => {
1137                let mut new_map = HashMap::with_capacity(map.len());
1138                let mut stream = map.into_stream();
1139                while let Some((token, token_id)) = stream.next() {
1140                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1141                }
1142
1143                new_map
1144            }
1145        };
1146
1147        map.retain(
1148            |_, token_id| match removed_token_ids.binary_search(token_id) {
1149                Ok(_) => false,
1150                Err(index) => {
1151                    *token_id -= index as u32;
1152                    true
1153                }
1154            },
1155        );
1156
1157        self.tokens = TokenMap::HashMap(map);
1158    }
1159
1160    pub fn next_id(&self) -> u32 {
1161        self.next_id
1162    }
1163}
1164
1165pub struct PostingListReader {
1166    reader: Arc<dyn IndexReader>,
1167
1168    // legacy format only
1169    offsets: Option<Vec<usize>>,
1170
1171    // from metadata for legacy format
1172    // from column for new format
1173    max_scores: Option<Vec<f32>>,
1174
1175    // new format only
1176    lengths: Option<Vec<u32>>,
1177
1178    has_position: bool,
1179
1180    index_cache: WeakLanceCache,
1181}
1182
1183impl std::fmt::Debug for PostingListReader {
1184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1185        f.debug_struct("InvertedListReader")
1186            .field("offsets", &self.offsets)
1187            .field("max_scores", &self.max_scores)
1188            .finish()
1189    }
1190}
1191
1192impl DeepSizeOf for PostingListReader {
1193    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
1194        self.offsets.deep_size_of_children(context)
1195            + self.max_scores.deep_size_of_children(context)
1196            + self.lengths.deep_size_of_children(context)
1197    }
1198}
1199
1200impl PostingListReader {
1201    pub(crate) async fn try_new(
1202        reader: Arc<dyn IndexReader>,
1203        index_cache: &LanceCache,
1204    ) -> Result<Self> {
1205        let has_position = reader.schema().field(POSITION_COL).is_some();
1206        let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
1207            let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
1208            (Some(offsets), max_scores, None)
1209        } else {
1210            let metadata = reader
1211                .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
1212                .await?;
1213            let max_scores = metadata[MAX_SCORE_COL]
1214                .as_primitive::<Float32Type>()
1215                .values()
1216                .to_vec();
1217            let lengths = metadata[LENGTH_COL]
1218                .as_primitive::<UInt32Type>()
1219                .values()
1220                .to_vec();
1221            (None, Some(max_scores), Some(lengths))
1222        };
1223
1224        Ok(Self {
1225            reader,
1226            offsets,
1227            max_scores,
1228            lengths,
1229            has_position,
1230            index_cache: WeakLanceCache::from(index_cache),
1231        })
1232    }
1233
1234    // for legacy format
1235    // returns the offsets and max scores
1236    fn load_metadata(
1237        schema: &lance_core::datatypes::Schema,
1238    ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
1239        let offsets = schema.metadata.get("offsets").ok_or(Error::Index {
1240            message: "offsets not found in metadata".to_owned(),
1241            location: location!(),
1242        })?;
1243        let offsets = serde_json::from_str(offsets)?;
1244
1245        let max_scores = schema
1246            .metadata
1247            .get("max_scores")
1248            .map(|max_scores| serde_json::from_str(max_scores))
1249            .transpose()?;
1250        Ok((offsets, max_scores))
1251    }
1252
1253    // the number of posting lists
1254    pub fn len(&self) -> usize {
1255        match self.offsets {
1256            Some(ref offsets) => offsets.len(),
1257            None => self.reader.num_rows(),
1258        }
1259    }
1260
1261    pub fn is_empty(&self) -> bool {
1262        self.len() == 0
1263    }
1264
1265    pub(crate) fn has_positions(&self) -> bool {
1266        self.has_position
1267    }
1268
1269    pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1270        let token_id = token_id as usize;
1271
1272        match self.offsets {
1273            Some(ref offsets) => {
1274                let next_offset = offsets
1275                    .get(token_id + 1)
1276                    .copied()
1277                    .unwrap_or(self.reader.num_rows());
1278                next_offset - offsets[token_id]
1279            }
1280            None => {
1281                if let Some(lengths) = &self.lengths {
1282                    lengths[token_id] as usize
1283                } else {
1284                    panic!("posting list reader is not initialized")
1285                }
1286            }
1287        }
1288    }
1289
1290    pub(crate) async fn posting_batch(
1291        &self,
1292        token_id: u32,
1293        with_position: bool,
1294    ) -> Result<RecordBatch> {
1295        if self.offsets.is_some() {
1296            self.posting_batch_legacy(token_id, with_position).await
1297        } else {
1298            let token_id = token_id as usize;
1299            let columns = if with_position {
1300                vec![POSTING_COL, POSITION_COL]
1301            } else {
1302                vec![POSTING_COL]
1303            };
1304            let batch = self
1305                .reader
1306                .read_range(token_id..token_id + 1, Some(&columns))
1307                .await?;
1308            Ok(batch)
1309        }
1310    }
1311
1312    async fn posting_batch_legacy(
1313        &self,
1314        token_id: u32,
1315        with_position: bool,
1316    ) -> Result<RecordBatch> {
1317        let mut columns = vec![ROW_ID, FREQUENCY_COL];
1318        if with_position {
1319            columns.push(POSITION_COL);
1320        }
1321
1322        let length = self.posting_len(token_id);
1323        let token_id = token_id as usize;
1324        let offset = self.offsets.as_ref().unwrap()[token_id];
1325        let batch = self
1326            .reader
1327            .read_range(offset..offset + length, Some(&columns))
1328            .await?;
1329        Ok(batch)
1330    }
1331
    /// Returns the posting list for `token_id`, going through the index cache.
    ///
    /// On a cache miss, the loader records one part load on `metrics`, reads the
    /// posting batch without positions, and decodes it. The result is cached under
    /// `PostingListKey { token_id }`. The caller receives its own clone of the
    /// cached value.
    ///
    /// When `is_phrase_query` is set, positions are fetched via `read_positions`
    /// and attached to that clone only.
    ///
    /// # Errors
    /// Failures from the cache/loader are converted with `Error::io`. Errors from
    /// `read_positions` are propagated unchanged.
    #[instrument(level = "debug", skip(self, metrics))]
    pub(crate) async fn posting_list(
        &self,
        token_id: u32,
        is_phrase_query: bool,
        metrics: &dyn MetricsCollector,
    ) -> Result<PostingList> {
        let cache_key = PostingListKey { token_id };
        let mut posting = self
            .index_cache
            .get_or_insert_with_key(cache_key, || async move {
                metrics.record_part_load();
                info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
                let batch = self.posting_batch(token_id, false).await?;
                self.posting_list_from_batch(&batch, token_id)
            })
            .await
            .map_err(|e| Error::io(e.to_string(), location!()))?
            .as_ref()
            .clone();

        if is_phrase_query {
            // Cached posting lists never carry positions: the loader above (like
            // `prewarm`) reads the batch without the positions column. So positions
            // are loaded separately and attached to this caller's copy.
            let positions = self.read_positions(token_id).await?;
            posting.set_positions(positions);
        }

        Ok(posting)
    }
1361
1362    pub(crate) fn posting_list_from_batch(
1363        &self,
1364        batch: &RecordBatch,
1365        token_id: u32,
1366    ) -> Result<PostingList> {
1367        let posting_list = PostingList::from_batch(
1368            batch,
1369            self.max_scores
1370                .as_ref()
1371                .map(|max_scores| max_scores[token_id as usize]),
1372            self.lengths
1373                .as_ref()
1374                .map(|lengths| lengths[token_id as usize]),
1375        )?;
1376        Ok(posting_list)
1377    }
1378
1379    async fn prewarm(&self) -> Result<()> {
1380        let batch = self.read_batch(false).await?;
1381        for token_id in 0..self.len() {
1382            let posting_range = self.posting_list_range(token_id as u32);
1383            let batch = batch.slice(posting_range.start, posting_range.end - posting_range.start);
1384            // Apply shrink_to_fit to create a deep copy with compacted buffers
1385            // This ensures each cached entry has its own memory, not shared references
1386            let batch = batch.shrink_to_fit()?;
1387            let posting_list = self.posting_list_from_batch(&batch, token_id as u32)?;
1388            let inserted = self
1389                .index_cache
1390                .insert_with_key(
1391                    &PostingListKey {
1392                        token_id: token_id as u32,
1393                    },
1394                    Arc::new(posting_list),
1395                )
1396                .await;
1397
1398            if !inserted {
1399                return Err(Error::Internal {
1400                    message: "Failed to prewarm index: cache is no longer available".to_string(),
1401                    location: location!(),
1402                });
1403            }
1404        }
1405
1406        Ok(())
1407    }
1408
1409    pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1410        let columns = self.posting_columns(with_position);
1411        let batch = self
1412            .reader
1413            .read_range(0..self.reader.num_rows(), Some(&columns))
1414            .await?;
1415        Ok(batch)
1416    }
1417
1418    pub(crate) async fn read_all(
1419        &self,
1420        with_position: bool,
1421    ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1422        let batch = self.read_batch(with_position).await?;
1423        Ok((0..self.len()).map(move |i| {
1424            let token_id = i as u32;
1425            let range = self.posting_list_range(token_id);
1426            let batch = batch.slice(i, range.end - range.start);
1427            self.posting_list_from_batch(&batch, token_id)
1428        }))
1429    }
1430
    /// Reads the positions column for `token_id`'s rows (`posting_list_range`),
    /// caching the result under `PositionKey { token_id }`.
    ///
    /// A `Error::Schema` from the read (presumably because the index was built
    /// without positions) is turned into an invalid-input error asking the user to
    /// recreate the index with positions. Other errors pass through unchanged.
    async fn read_positions(&self, token_id: u32) -> Result<ListArray> {
        let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
            let batch = self
                .reader
                .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
                .await.map_err(|e| {
                    match e {
                        Error::Schema { .. } => Error::invalid_input(
                            "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
                            location!(),
                        ),
                        e => e
                    }
                })?;
            Result::Ok(Positions(batch[POSITION_COL]
                .as_list::<i32>()
                .clone()))
        }).await?;
        Ok(positions.0.clone())
    }
1451
1452    fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1453        match self.offsets {
1454            Some(ref offsets) => {
1455                let offset = offsets[token_id as usize];
1456                let posting_len = self.posting_len(token_id);
1457                offset..offset + posting_len
1458            }
1459            None => {
1460                let token_id = token_id as usize;
1461                token_id..token_id + 1
1462            }
1463        }
1464    }
1465
1466    fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1467        let mut base_columns = match self.offsets {
1468            Some(_) => vec![ROW_ID, FREQUENCY_COL],
1469            None => vec![POSTING_COL],
1470        };
1471        if with_position {
1472            base_columns.push(POSITION_COL);
1473        }
1474        base_columns
1475    }
1476}
1477
1478/// New type just to allow Positions implement DeepSizeOf so it can be put
1479/// in the cache.
1480#[derive(Clone)]
1481pub struct Positions(ListArray);
1482
1483impl DeepSizeOf for Positions {
1484    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1485        self.0.get_buffer_memory_size()
1486    }
1487}
1488
1489// Cache key implementations for type-safe cache access
1490#[derive(Debug, Clone)]
1491pub struct PostingListKey {
1492    pub token_id: u32,
1493}
1494
1495impl CacheKey for PostingListKey {
1496    type ValueType = PostingList;
1497
1498    fn key(&self) -> std::borrow::Cow<'_, str> {
1499        format!("postings-{}", self.token_id).into()
1500    }
1501}
1502
1503#[derive(Debug, Clone)]
1504pub struct PositionKey {
1505    pub token_id: u32,
1506}
1507
1508impl CacheKey for PositionKey {
1509    type ValueType = Positions;
1510
1511    fn key(&self) -> std::borrow::Cow<'_, str> {
1512        format!("positions-{}", self.token_id).into()
1513    }
1514}
1515
/// A single token's posting list in one of the two storage layouts.
#[derive(Debug, Clone, DeepSizeOf)]
pub enum PostingList {
    /// Legacy layout: parallel row-id / frequency buffers, one entry per row.
    Plain(PlainPostingList),
    /// New layout: blocks of compressed doc ids and frequencies from a single row.
    Compressed(CompressedPostingList),
}
1521
1522impl PostingList {
1523    pub fn from_batch(
1524        batch: &RecordBatch,
1525        max_score: Option<f32>,
1526        length: Option<u32>,
1527    ) -> Result<Self> {
1528        match batch.column_by_name(POSTING_COL) {
1529            Some(_) => {
1530                debug_assert!(max_score.is_some() && length.is_some());
1531                let posting =
1532                    CompressedPostingList::from_batch(batch, max_score.unwrap(), length.unwrap());
1533                Ok(Self::Compressed(posting))
1534            }
1535            None => {
1536                let posting = PlainPostingList::from_batch(batch, max_score);
1537                Ok(Self::Plain(posting))
1538            }
1539        }
1540    }
1541
1542    pub fn iter(&self) -> PostingListIterator<'_> {
1543        PostingListIterator::new(self)
1544    }
1545
1546    pub fn has_position(&self) -> bool {
1547        match self {
1548            Self::Plain(posting) => posting.positions.is_some(),
1549            Self::Compressed(posting) => posting.positions.is_some(),
1550        }
1551    }
1552
1553    pub fn set_positions(&mut self, positions: ListArray) {
1554        match self {
1555            Self::Plain(posting) => posting.positions = Some(positions),
1556            Self::Compressed(posting) => {
1557                posting.positions = Some(positions.value(0).as_list::<i32>().clone());
1558            }
1559        }
1560    }
1561
1562    pub fn max_score(&self) -> Option<f32> {
1563        match self {
1564            Self::Plain(posting) => posting.max_score,
1565            Self::Compressed(posting) => Some(posting.max_score),
1566        }
1567    }
1568
1569    pub fn len(&self) -> usize {
1570        match self {
1571            Self::Plain(posting) => posting.len(),
1572            Self::Compressed(posting) => posting.length as usize,
1573        }
1574    }
1575
1576    pub fn is_empty(&self) -> bool {
1577        self.len() == 0
1578    }
1579
    /// Converts this posting list into a `PostingListBuilder` (new layout).
    ///
    /// For the legacy `Plain` layout:
    /// - row ids are mapped to doc ids, where a row's doc id is its index in
    ///   `docs.row_ids`;
    /// - entries are sorted by doc id before being added;
    /// - frequencies are truncated from `f32` to `u32` when positions are absent.
    ///
    /// `Compressed` entries are already keyed by doc id and are added in iteration
    /// order.
    ///
    /// # Panics
    /// Panics if a legacy row id is not present in `docs.row_ids`.
    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
        let mut builder = PostingListBuilder::new(self.has_position());
        match self {
            // legacy format
            Self::Plain(posting) => {
                // convert the posting list to the new format:
                // 1. map row ids to doc ids
                // 2. sort the posting list by doc ids
                struct Item {
                    doc_id: u32,
                    positions: PositionRecorder,
                }
                // NOTE(review): this row_id -> doc_id map is rebuilt from every doc on
                // each call, i.e. once per posting list when a whole partition is
                // converted. Consider building it once at the call site.
                let doc_ids = docs
                    .row_ids
                    .iter()
                    .enumerate()
                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
                    .collect::<HashMap<_, _>>();
                let mut items = Vec::with_capacity(posting.len());
                for (row_id, freq, positions) in posting.iter() {
                    let freq = freq as u32;
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>())
                        }
                        None => PositionRecorder::Count(freq),
                    };
                    items.push(Item {
                        doc_id: doc_ids[&row_id],
                        positions,
                    });
                }
                items.sort_unstable_by_key(|item| item.doc_id);
                for item in items {
                    builder.add(item.doc_id, item.positions);
                }
            }
            Self::Compressed(posting) => {
                posting.iter().for_each(|(doc_id, freq, positions)| {
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>())
                        }
                        None => PositionRecorder::Count(freq),
                    };
                    builder.add(doc_id, positions);
                });
            }
        }
        builder
    }
1631}
1632
1633#[derive(Debug, PartialEq, Clone)]
1634pub struct PlainPostingList {
1635    pub row_ids: ScalarBuffer<u64>,
1636    pub frequencies: ScalarBuffer<f32>,
1637    pub max_score: Option<f32>,
1638    pub positions: Option<ListArray>, // List of Int32
1639}
1640
1641impl DeepSizeOf for PlainPostingList {
1642    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1643        self.row_ids.len() * std::mem::size_of::<u64>()
1644            + self.frequencies.len() * std::mem::size_of::<u32>()
1645            + self
1646                .positions
1647                .as_ref()
1648                .map(|positions| positions.get_buffer_memory_size())
1649                .unwrap_or(0)
1650    }
1651}
1652
1653impl PlainPostingList {
1654    pub fn new(
1655        row_ids: ScalarBuffer<u64>,
1656        frequencies: ScalarBuffer<f32>,
1657        max_score: Option<f32>,
1658        positions: Option<ListArray>,
1659    ) -> Self {
1660        Self {
1661            row_ids,
1662            frequencies,
1663            max_score,
1664            positions,
1665        }
1666    }
1667
1668    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
1669        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
1670        let frequencies = batch[FREQUENCY_COL]
1671            .as_primitive::<Float32Type>()
1672            .values()
1673            .clone();
1674        let positions = batch
1675            .column_by_name(POSITION_COL)
1676            .map(|col| col.as_list::<i32>().clone());
1677
1678        Self::new(row_ids, frequencies, max_score, positions)
1679    }
1680
1681    pub fn len(&self) -> usize {
1682        self.row_ids.len()
1683    }
1684
1685    pub fn is_empty(&self) -> bool {
1686        self.len() == 0
1687    }
1688
    /// Iterates over `(row_id, frequency, positions)` for every entry, in storage order.
    ///
    /// `positions` is `None` when the list has no positions column. Otherwise it
    /// iterates this entry's slice of the Int32 child array (bounded by the list's
    /// value offsets), yielding each position as `u32`.
    pub fn iter(&self) -> PlainPostingListIterator<'_> {
        Box::new(
            self.row_ids
                .iter()
                .zip(self.frequencies.iter())
                .enumerate()
                // `doc_id` here is the row id stored by the legacy layout.
                .map(|(idx, (doc_id, freq))| {
                    (
                        *doc_id,
                        *freq,
                        self.positions.as_ref().map(|p| {
                            let start = p.value_offsets()[idx] as usize;
                            let end = p.value_offsets()[idx + 1] as usize;
                            Box::new(
                                p.values().as_primitive::<Int32Type>().values()[start..end]
                                    .iter()
                                    .map(|pos| *pos as u32),
                            ) as _
                        }),
                    )
                }),
        )
    }
1712
1713    #[inline]
1714    pub fn doc(&self, i: usize) -> LocatedDocInfo {
1715        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
1716    }
1717
1718    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
1719        self.positions
1720            .as_ref()
1721            .map(|positions| positions.value(index))
1722    }
1723
1724    pub fn max_score(&self) -> Option<f32> {
1725        self.max_score
1726    }
1727
1728    pub fn row_id(&self, i: usize) -> u64 {
1729        self.row_ids[i]
1730    }
1731}
1732
1733#[derive(Debug, PartialEq, Clone)]
1734pub struct CompressedPostingList {
1735    pub max_score: f32,
1736    pub length: u32,
1737    // each binary is a block of compressed data
1738    // that contains `BLOCK_SIZE` doc ids and then `BLOCK_SIZE` frequencies
1739    pub blocks: LargeBinaryArray,
1740    pub positions: Option<ListArray>,
1741}
1742
1743impl DeepSizeOf for CompressedPostingList {
1744    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1745        self.blocks.get_buffer_memory_size()
1746            + self
1747                .positions
1748                .as_ref()
1749                .map(|positions| positions.get_buffer_memory_size())
1750                .unwrap_or(0)
1751    }
1752}
1753
1754impl CompressedPostingList {
1755    pub fn new(
1756        blocks: LargeBinaryArray,
1757        max_score: f32,
1758        length: u32,
1759        positions: Option<ListArray>,
1760    ) -> Self {
1761        Self {
1762            max_score,
1763            length,
1764            blocks,
1765            positions,
1766        }
1767    }
1768
1769    pub fn from_batch(batch: &RecordBatch, max_score: f32, length: u32) -> Self {
1770        debug_assert_eq!(batch.num_rows(), 1);
1771        let blocks = batch[POSTING_COL]
1772            .as_list::<i32>()
1773            .value(0)
1774            .as_binary::<i64>()
1775            .clone();
1776        let positions = batch
1777            .column_by_name(POSITION_COL)
1778            .map(|col| col.as_list::<i32>().value(0).as_list::<i32>().clone());
1779
1780        Self {
1781            max_score,
1782            length,
1783            blocks,
1784            positions,
1785        }
1786    }
1787
1788    pub fn iter(&self) -> CompressedPostingListIterator {
1789        CompressedPostingListIterator::new(
1790            self.length as usize,
1791            self.blocks.clone(),
1792            self.positions.clone(),
1793        )
1794    }
1795
1796    pub fn block_max_score(&self, block_idx: usize) -> f32 {
1797        let block = self.blocks.value(block_idx);
1798        block[0..4].try_into().map(f32::from_le_bytes).unwrap()
1799    }
1800
1801    pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
1802        let block = self.blocks.value(block_idx);
1803        block[4..8].try_into().map(u32::from_le_bytes).unwrap()
1804    }
1805}
1806
1807#[derive(Debug)]
1808pub struct PostingListBuilder {
1809    pub doc_ids: ExpLinkedList<u32>,
1810    pub frequencies: ExpLinkedList<u32>,
1811    pub positions: Option<PositionBuilder>,
1812}
1813
1814impl PostingListBuilder {
1815    pub fn size(&self) -> u64 {
1816        (std::mem::size_of::<u32>() * self.doc_ids.len()
1817            + std::mem::size_of::<u32>() * self.frequencies.len()
1818            + self
1819                .positions
1820                .as_ref()
1821                .map(|positions| positions.size())
1822                .unwrap_or(0)) as u64
1823    }
1824
1825    pub fn has_positions(&self) -> bool {
1826        self.positions.is_some()
1827    }
1828
1829    pub fn new(with_position: bool) -> Self {
1830        Self {
1831            doc_ids: ExpLinkedList::new().with_capacity_limit(128),
1832            frequencies: ExpLinkedList::new().with_capacity_limit(128),
1833            positions: with_position.then(PositionBuilder::new),
1834        }
1835    }
1836
1837    pub fn len(&self) -> usize {
1838        self.doc_ids.len()
1839    }
1840
1841    pub fn is_empty(&self) -> bool {
1842        self.len() == 0
1843    }
1844
1845    pub fn iter(&self) -> impl Iterator<Item = (&u32, &u32, Option<&[u32]>)> {
1846        self.doc_ids
1847            .iter()
1848            .zip(self.frequencies.iter())
1849            .enumerate()
1850            .map(|(idx, (doc_id, freq))| {
1851                let positions = self.positions.as_ref().map(|positions| positions.get(idx));
1852                (doc_id, freq, positions)
1853            })
1854    }
1855
1856    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
1857        self.doc_ids.push(doc_id);
1858        self.frequencies.push(term_positions.len());
1859        if let Some(positions) = self.positions.as_mut() {
1860            positions.push(term_positions.into_vec());
1861        }
1862    }
1863
1864    // assume the posting list is sorted by doc id
1865    pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
1866        let length = self.len();
1867        let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);
1868
1869        let schema = inverted_list_schema(self.has_positions());
1870        let compressed = compress_posting_list(
1871            self.doc_ids.len(),
1872            self.doc_ids.iter(),
1873            self.frequencies.iter(),
1874            block_max_scores.into_iter(),
1875        )?;
1876        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
1877        let mut columns = vec![
1878            Arc::new(ListArray::try_new(
1879                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
1880                offsets,
1881                Arc::new(compressed),
1882                None,
1883            )?) as ArrayRef,
1884            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
1885            Arc::new(UInt32Array::from_iter_values(std::iter::once(
1886                self.len() as u32
1887            ))) as ArrayRef,
1888        ];
1889
1890        if let Some(positions) = self.positions.as_ref() {
1891            let mut position_builder = ListBuilder::new(ListBuilder::with_capacity(
1892                LargeBinaryBuilder::new(),
1893                length,
1894            ));
1895            for index in 0..length {
1896                let positions_in_doc = positions.get(index);
1897                let compressed = compress_positions(positions_in_doc)?;
1898                let inner_builder = position_builder.values();
1899                inner_builder.append_value(compressed.into_iter());
1900            }
1901            position_builder.append(true);
1902            let position_col = position_builder.finish();
1903            columns.push(Arc::new(position_col));
1904        }
1905
1906        let batch = RecordBatch::try_new(schema, columns)?;
1907        Ok(batch)
1908    }
1909
1910    pub fn remap(&mut self, removed: &[u32]) {
1911        let mut cursor = 0;
1912        let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
1913        let mut new_frequencies = ExpLinkedList::with_capacity(self.len());
1914        let mut new_positions = self.positions.as_mut().map(|_| PositionBuilder::new());
1915        for (&doc_id, &freq, positions) in self.iter() {
1916            while cursor < removed.len() && removed[cursor] < doc_id {
1917                cursor += 1;
1918            }
1919            if cursor < removed.len() && removed[cursor] == doc_id {
1920                // this doc is removed
1921                continue;
1922            }
1923            // there are cursor removed docs before this doc
1924            // so we need to shift the doc id
1925            new_doc_ids.push(doc_id - cursor as u32);
1926            new_frequencies.push(freq);
1927            if let Some(new_positions) = new_positions.as_mut() {
1928                new_positions.push(positions.unwrap().to_vec());
1929            }
1930        }
1931
1932        self.doc_ids = new_doc_ids;
1933        self.frequencies = new_frequencies;
1934        self.positions = new_positions;
1935    }
1936}
1937
1938#[derive(Debug, Clone, DeepSizeOf)]
1939pub struct PositionBuilder {
1940    positions: Vec<u32>,
1941    offsets: Vec<i32>,
1942}
1943
1944impl Default for PositionBuilder {
1945    fn default() -> Self {
1946        Self::new()
1947    }
1948}
1949
1950impl PositionBuilder {
1951    pub fn new() -> Self {
1952        Self {
1953            positions: Vec::new(),
1954            offsets: vec![0],
1955        }
1956    }
1957
1958    pub fn size(&self) -> usize {
1959        std::mem::size_of::<u32>() * self.positions.len()
1960            + std::mem::size_of::<i32>() * self.offsets.len()
1961    }
1962
1963    pub fn total_len(&self) -> usize {
1964        self.positions.len()
1965    }
1966
1967    pub fn push(&mut self, positions: Vec<u32>) {
1968        self.positions.extend(positions);
1969        self.offsets.push(self.positions.len() as i32);
1970    }
1971
1972    pub fn get(&self, i: usize) -> &[u32] {
1973        let start = self.offsets[i] as usize;
1974        let end = self.offsets[i + 1] as usize;
1975        &self.positions[start..end]
1976    }
1977}
1978
1979impl From<Vec<Vec<u32>>> for PositionBuilder {
1980    fn from(positions: Vec<Vec<u32>>) -> Self {
1981        let mut builder = Self::new();
1982        builder.offsets.reserve(positions.len());
1983        for pos in positions {
1984            builder.push(pos);
1985        }
1986        builder
1987    }
1988}
1989
1990#[derive(Debug, Clone, DeepSizeOf, Copy)]
1991pub enum DocInfo {
1992    Located(LocatedDocInfo),
1993    Raw(RawDocInfo),
1994}
1995
1996impl DocInfo {
1997    pub fn doc_id(&self) -> u64 {
1998        match self {
1999            Self::Raw(info) => info.doc_id as u64,
2000            Self::Located(info) => info.row_id,
2001        }
2002    }
2003
2004    pub fn frequency(&self) -> u32 {
2005        match self {
2006            Self::Raw(info) => info.frequency,
2007            Self::Located(info) => info.frequency as u32,
2008        }
2009    }
2010}
2011
2012impl Eq for DocInfo {}
2013
2014impl PartialEq for DocInfo {
2015    fn eq(&self, other: &Self) -> bool {
2016        self.doc_id() == other.doc_id()
2017    }
2018}
2019
2020impl PartialOrd for DocInfo {
2021    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2022        Some(self.cmp(other))
2023    }
2024}
2025
2026impl Ord for DocInfo {
2027    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2028        self.doc_id().cmp(&other.doc_id())
2029    }
2030}
2031
2032#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
2033pub struct LocatedDocInfo {
2034    pub row_id: u64,
2035    pub frequency: f32,
2036}
2037
2038impl LocatedDocInfo {
2039    pub fn new(row_id: u64, frequency: f32) -> Self {
2040        Self { row_id, frequency }
2041    }
2042}
2043
2044impl Eq for LocatedDocInfo {}
2045
2046impl PartialEq for LocatedDocInfo {
2047    fn eq(&self, other: &Self) -> bool {
2048        self.row_id == other.row_id
2049    }
2050}
2051
2052impl PartialOrd for LocatedDocInfo {
2053    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2054        Some(self.cmp(other))
2055    }
2056}
2057
2058impl Ord for LocatedDocInfo {
2059    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2060        self.row_id.cmp(&other.row_id)
2061    }
2062}
2063
2064#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
2065pub struct RawDocInfo {
2066    pub doc_id: u32,
2067    pub frequency: u32,
2068}
2069
2070impl RawDocInfo {
2071    pub fn new(doc_id: u32, frequency: u32) -> Self {
2072        Self { doc_id, frequency }
2073    }
2074}
2075
2076impl Eq for RawDocInfo {}
2077
2078impl PartialEq for RawDocInfo {
2079    fn eq(&self, other: &Self) -> bool {
2080        self.doc_id == other.doc_id
2081    }
2082}
2083
2084impl PartialOrd for RawDocInfo {
2085    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
2086        Some(self.cmp(other))
2087    }
2088}
2089
2090impl Ord for RawDocInfo {
2091    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
2092        self.doc_id.cmp(&other.doc_id)
2093    }
2094}
2095
2096// DocSet is a mapping from row ids to the number of tokens in the document
2097// It's used to sort the documents by the bm25 score
2098#[derive(Debug, Clone, Default, DeepSizeOf)]
2099pub struct DocSet {
2100    row_ids: Vec<u64>,
2101    num_tokens: Vec<u32>,
2102    // (row_id, doc_id) pairs sorted by row_id
2103    inv: Vec<(u64, u32)>,
2104
2105    total_tokens: u64,
2106}
2107
2108impl DocSet {
2109    #[inline]
2110    pub fn len(&self) -> usize {
2111        self.row_ids.len()
2112    }
2113
2114    pub fn is_empty(&self) -> bool {
2115        self.len() == 0
2116    }
2117
2118    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
2119        self.row_ids.iter().zip(self.num_tokens.iter())
2120    }
2121
2122    pub fn row_id(&self, doc_id: u32) -> u64 {
2123        self.row_ids[doc_id as usize]
2124    }
2125
2126    pub fn doc_id(&self, row_id: u64) -> Option<u64> {
2127        if self.inv.is_empty() {
2128            // in legacy format, the row id is doc id
2129            match self.row_ids.binary_search(&row_id) {
2130                Ok(_) => Some(row_id),
2131                Err(_) => None,
2132            }
2133        } else {
2134            match self.inv.binary_search_by_key(&row_id, |x| x.0) {
2135                Ok(idx) => Some(self.inv[idx].1 as u64),
2136                Err(_) => None,
2137            }
2138        }
2139    }
2140    pub fn total_tokens_num(&self) -> u64 {
2141        self.total_tokens
2142    }
2143
2144    #[inline]
2145    pub fn average_length(&self) -> f32 {
2146        self.total_tokens as f32 / self.len() as f32
2147    }
2148
2149    pub fn calculate_block_max_scores<'a>(
2150        &self,
2151        doc_ids: impl Iterator<Item = &'a u32>,
2152        freqs: impl Iterator<Item = &'a u32>,
2153    ) -> Vec<f32> {
2154        let avgdl = self.average_length();
2155        let length = doc_ids.size_hint().0;
2156        let mut block_max_scores = Vec::with_capacity(length);
2157        let mut max_score = f32::MIN;
2158        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
2159            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
2160            let freq = *freq as f32;
2161            let score = freq / (freq + doc_norm);
2162            if score > max_score {
2163                max_score = score;
2164            }
2165            if (i + 1) % BLOCK_SIZE == 0 {
2166                max_score *= idf(length, self.len()) * (K1 + 1.0);
2167                block_max_scores.push(max_score);
2168                max_score = f32::MIN;
2169            }
2170        }
2171        if length % BLOCK_SIZE > 0 {
2172            max_score *= idf(length, self.len()) * (K1 + 1.0);
2173            block_max_scores.push(max_score);
2174        }
2175        block_max_scores
2176    }
2177
2178    pub fn to_batch(&self) -> Result<RecordBatch> {
2179        let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
2180        let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
2181
2182        let schema = arrow_schema::Schema::new(vec![
2183            arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
2184            arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
2185        ]);
2186
2187        let batch = RecordBatch::try_new(
2188            Arc::new(schema),
2189            vec![
2190                Arc::new(row_id_col) as ArrayRef,
2191                Arc::new(num_tokens_col) as ArrayRef,
2192            ],
2193        )?;
2194        Ok(batch)
2195    }
2196
2197    pub async fn load(
2198        reader: Arc<dyn IndexReader>,
2199        is_legacy: bool,
2200        frag_reuse_index: Option<Arc<FragReuseIndex>>,
2201    ) -> Result<Self> {
2202        let batch = reader.read_range(0..reader.num_rows(), None).await?;
2203        let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
2204        let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
2205
2206        // for legacy format, the row id is doc id; sorting keeps binary search viable
2207        if is_legacy {
2208            let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
2209                .values()
2210                .iter()
2211                .filter_map(|id| {
2212                    if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2213                        frag_reuse_index_ref.remap_row_id(*id)
2214                    } else {
2215                        Some(*id)
2216                    }
2217                })
2218                .zip(num_tokens_col.values().iter())
2219                .sorted_unstable_by_key(|x| x.0)
2220                .unzip();
2221
2222            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2223            return Ok(Self {
2224                row_ids,
2225                num_tokens,
2226                inv: Vec::new(),
2227                total_tokens,
2228            });
2229        }
2230
2231        // if frag reuse happened, we'll need to remap the row_ids. And after row_ids been
2232        // remapped, we'll need resort to make sure binary_search works.
2233        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2234            let mut row_ids = Vec::with_capacity(row_id_col.len());
2235            let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
2236            for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
2237                if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
2238                    row_ids.push(new_row_id);
2239                    num_tokens.push(*num_token);
2240                }
2241            }
2242
2243            let mut inv: Vec<(u64, u32)> = row_ids
2244                .iter()
2245                .enumerate()
2246                .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2247                .collect();
2248            inv.sort_unstable_by_key(|entry| entry.0);
2249
2250            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2251            return Ok(Self {
2252                row_ids,
2253                num_tokens,
2254                inv,
2255                total_tokens,
2256            });
2257        }
2258
2259        let row_ids = row_id_col.values().to_vec();
2260        let num_tokens = num_tokens_col.values().to_vec();
2261        let mut inv: Vec<(u64, u32)> = row_ids
2262            .iter()
2263            .enumerate()
2264            .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2265            .collect();
2266        if !row_ids.is_sorted() {
2267            inv.sort_unstable_by_key(|entry| entry.0);
2268        }
2269        let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2270        Ok(Self {
2271            row_ids,
2272            num_tokens,
2273            inv,
2274            total_tokens,
2275        })
2276    }
2277
2278    // remap the row ids to the new row ids
2279    // returns the removed doc ids
2280    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
2281        let mut removed = Vec::new();
2282        let len = self.len();
2283        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
2284        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
2285        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
2286            match mapping.get(&row_id) {
2287                Some(Some(new_row_id)) => {
2288                    self.row_ids.push(*new_row_id);
2289                    self.num_tokens.push(num_token);
2290                }
2291                Some(None) => {
2292                    removed.push(doc_id as u32);
2293                }
2294                None => {
2295                    self.row_ids.push(row_id);
2296                    self.num_tokens.push(num_token);
2297                }
2298            }
2299        }
2300        removed
2301    }
2302
2303    #[inline]
2304    pub fn num_tokens(&self, doc_id: u32) -> u32 {
2305        self.num_tokens[doc_id as usize]
2306    }
2307
2308    // this can be used only if it's a legacy format,
2309    // which store the sorted row ids so that we can use binary search
2310    #[inline]
2311    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
2312        self.row_ids
2313            .binary_search(&row_id)
2314            .map(|idx| self.num_tokens[idx])
2315            .unwrap_or(0)
2316    }
2317
2318    // append a document to the doc set
2319    // returns the doc_id (the number of documents before appending)
2320    pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
2321        self.row_ids.push(row_id);
2322        self.num_tokens.push(num_tokens);
2323        self.total_tokens += num_tokens as u64;
2324        self.row_ids.len() as u32 - 1
2325    }
2326}
2327
2328pub fn flat_full_text_search(
2329    batches: &[&RecordBatch],
2330    doc_col: &str,
2331    query: &str,
2332    tokenizer: Option<Box<dyn LanceTokenizer>>,
2333) -> Result<Vec<u64>> {
2334    if batches.is_empty() {
2335        return Ok(vec![]);
2336    }
2337
2338    if is_phrase_query(query) {
2339        return Err(Error::invalid_input(
2340            "phrase query is not supported for flat full text search, try using FTS index",
2341            location!(),
2342        ));
2343    }
2344
2345    match batches[0][doc_col].data_type() {
2346        DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
2347        DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
2348        data_type => Err(Error::invalid_input(
2349            format!("unsupported data type {} for inverted index", data_type),
2350            location!(),
2351        )),
2352    }
2353}
2354
2355fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
2356    batches: &[&RecordBatch],
2357    doc_col: &str,
2358    query: &str,
2359    tokenizer: Option<Box<dyn LanceTokenizer>>,
2360) -> Result<Vec<u64>> {
2361    let mut results = Vec::new();
2362    let mut tokenizer =
2363        tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
2364    let query_tokens = collect_query_tokens(query, &mut tokenizer, None);
2365
2366    for batch in batches {
2367        let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
2368        let doc_array = batch[doc_col].as_string::<Offset>();
2369        for i in 0..row_id_array.len() {
2370            let doc = doc_array.value(i);
2371            let doc_tokens = collect_doc_tokens(doc, &mut tokenizer, Some(&query_tokens));
2372            if !doc_tokens.is_empty() {
2373                results.push(row_id_array.value(i));
2374                assert!(doc.contains(query));
2375            }
2376        }
2377    }
2378
2379    Ok(results)
2380}
2381
2382#[allow(clippy::too_many_arguments)]
2383pub fn flat_bm25_search(
2384    batch: RecordBatch,
2385    doc_col: &str,
2386    query_tokens: &Tokens,
2387    tokenizer: &mut Box<dyn LanceTokenizer>,
2388    scorer: &mut MemBM25Scorer,
2389) -> std::result::Result<RecordBatch, DataFusionError> {
2390    let doc_iter = iter_str_array(&batch[doc_col]);
2391    let mut scores = Vec::with_capacity(batch.num_rows());
2392    for doc in doc_iter {
2393        let Some(doc) = doc else {
2394            scores.push(0.0);
2395            continue;
2396        };
2397
2398        let doc_tokens = collect_doc_tokens(doc, tokenizer, None);
2399        scorer.update(&doc_tokens);
2400        let doc_tokens = doc_tokens
2401            .into_iter()
2402            .filter(|t| query_tokens.contains(t))
2403            .collect::<Vec<_>>();
2404
2405        let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / scorer.avg_doc_length());
2406        let mut doc_token_count = HashMap::new();
2407        for token in doc_tokens {
2408            doc_token_count
2409                .entry(token)
2410                .and_modify(|count| *count += 1)
2411                .or_insert(1);
2412        }
2413        let mut score = 0.0;
2414        for token in query_tokens {
2415            let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;
2416
2417            let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
2418            score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
2419        }
2420        scores.push(score);
2421    }
2422
2423    let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
2424    let batch = batch
2425        .try_with_column(SCORE_FIELD.clone(), score_col)?
2426        .project_by_schema(&FTS_SCHEMA)?; // the scan node would probably scan some extra columns for prefilter, drop them here
2427    Ok(batch)
2428}
2429
2430pub fn flat_bm25_search_stream(
2431    input: SendableRecordBatchStream,
2432    doc_col: String,
2433    query: String,
2434    index: &Option<InvertedIndex>,
2435) -> SendableRecordBatchStream {
2436    let mut tokenizer = match index {
2437        Some(index) => index.tokenizer(),
2438        None => Box::new(TextTokenizer::new(
2439            tantivy::tokenizer::TextAnalyzer::builder(
2440                tantivy::tokenizer::SimpleTokenizer::default(),
2441            )
2442            .build(),
2443        )),
2444    };
2445    let tokens = collect_query_tokens(&query, &mut tokenizer, None);
2446
2447    let mut bm25_scorer = match index {
2448        Some(index) => {
2449            let index_bm25_scorer =
2450                IndexBM25Scorer::new(index.partitions.iter().map(|p| p.as_ref()));
2451            if index_bm25_scorer.num_docs() == 0 {
2452                MemBM25Scorer::new(0, 0, HashMap::new())
2453            } else {
2454                let mut token_docs = HashMap::with_capacity(tokens.len());
2455                for token in &tokens {
2456                    let token_nq = index_bm25_scorer.num_docs_containing_token(token).max(1);
2457                    token_docs.insert(token.clone(), token_nq);
2458                }
2459                MemBM25Scorer::new(
2460                    index_bm25_scorer.avg_doc_length() as u64 * index_bm25_scorer.num_docs() as u64,
2461                    index_bm25_scorer.num_docs(),
2462                    token_docs,
2463                )
2464            }
2465        }
2466        None => MemBM25Scorer::new(0, 0, HashMap::new()),
2467    };
2468
2469    let stream = input.map(move |batch| {
2470        let batch = batch?;
2471
2472        let batch = flat_bm25_search(batch, &doc_col, &tokens, &mut tokenizer, &mut bm25_scorer)?;
2473
2474        // filter out rows with score 0
2475        let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
2476        let mask = score_col
2477            .iter()
2478            .map(|score| score.is_some_and(|score| score > 0.0))
2479            .collect::<Vec<_>>();
2480        let mask = BooleanArray::from(mask);
2481        let batch = arrow::compute::filter_record_batch(&batch, &mask)?;
2482        debug_assert!(batch[ROW_ID].null_count() == 0, "flat FTS produces nulls");
2483        Ok(batch)
2484    });
2485
2486    Box::pin(RecordBatchStreamAdapter::new(FTS_SCHEMA.clone(), stream)) as SendableRecordBatchStream
2487}
2488
/// Returns `true` if `query` is a quoted phrase: at least two characters long, starting and
/// ending with a double quote.
///
/// A lone `"` is not a phrase. Without the length check, that single character would count
/// as both the opening and the closing quote.
pub fn is_phrase_query(query: &str) -> bool {
    query.len() >= 2 && query.starts_with('"') && query.ends_with('"')
}
2492
2493#[cfg(test)]
2494mod tests {
2495    use crate::scalar::inverted::lance_tokenizer::DocType;
2496    use lance_core::cache::LanceCache;
2497    use lance_core::utils::tempfile::TempObjDir;
2498    use lance_io::object_store::ObjectStore;
2499
2500    use crate::metrics::NoOpMetricsCollector;
2501    use crate::prefilter::NoFilter;
2502    use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
2503    use crate::scalar::inverted::encoding::decompress_posting_list;
2504    use crate::scalar::inverted::query::{FtsSearchParams, Operator};
2505    use crate::scalar::lance_format::LanceIndexStore;
2506
2507    use super::*;
2508
2509    #[tokio::test]
2510    async fn test_posting_builder_remap() {
2511        let mut builder = PostingListBuilder::new(false);
2512        let n = BLOCK_SIZE + 3;
2513        for i in 0..n {
2514            builder.add(i as u32, PositionRecorder::Count(1));
2515        }
2516        let removed = vec![5, 7];
2517        builder.remap(&removed);
2518
2519        let mut expected = PostingListBuilder::new(false);
2520        for i in 0..n - removed.len() {
2521            expected.add(i as u32, PositionRecorder::Count(1));
2522        }
2523        assert_eq!(builder.doc_ids, expected.doc_ids);
2524        assert_eq!(builder.frequencies, expected.frequencies);
2525
2526        // BLOCK_SIZE + 3 elements should be reduced to BLOCK_SIZE + 1,
2527        // there are still 2 blocks.
2528        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
2529        let (doc_ids, freqs) = decompress_posting_list(
2530            (n - removed.len()) as u32,
2531            batch[POSTING_COL]
2532                .as_list::<i32>()
2533                .value(0)
2534                .as_binary::<i64>(),
2535        )
2536        .unwrap();
2537        assert!(doc_ids
2538            .iter()
2539            .zip(expected.doc_ids.iter())
2540            .all(|(a, b)| a == b));
2541        assert!(freqs
2542            .iter()
2543            .zip(expected.frequencies.iter())
2544            .all(|(a, b)| a == b));
2545    }
2546
2547    #[tokio::test]
2548    async fn test_remap_to_empty_posting_list() {
2549        let tmpdir = TempObjDir::default();
2550        let store = Arc::new(LanceIndexStore::new(
2551            ObjectStore::local().into(),
2552            tmpdir.clone(),
2553            Arc::new(LanceCache::no_cache()),
2554        ));
2555
2556        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
2557
2558        // index of docs:
2559        // 0: lance
2560        // 1: lake lake
2561        // 2: lake lake lake
2562        builder.tokens.add("lance".to_owned());
2563        builder.tokens.add("lake".to_owned());
2564        builder.posting_lists.push(PostingListBuilder::new(false));
2565        builder.posting_lists.push(PostingListBuilder::new(false));
2566        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
2567        builder.posting_lists[1].add(1, PositionRecorder::Count(2));
2568        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
2569        builder.docs.append(0, 1);
2570        builder.docs.append(1, 1);
2571        builder.docs.append(2, 1);
2572        builder.write(store.as_ref()).await.unwrap();
2573
2574        let index = InvertedPartition::load(
2575            store.clone(),
2576            0,
2577            None,
2578            &LanceCache::no_cache(),
2579            TokenSetFormat::default(),
2580        )
2581        .await
2582        .unwrap();
2583        let mut builder = index.into_builder().await.unwrap();
2584
2585        let mapping = HashMap::from([(0, None), (2, Some(3))]);
2586        builder.remap(&mapping).await.unwrap();
2587
2588        // after remap, the doc 0 is removed, and the doc 2 is updated to 3
2589        assert_eq!(builder.tokens.len(), 1);
2590        assert_eq!(builder.tokens.get("lake"), Some(0));
2591        assert_eq!(builder.posting_lists.len(), 1);
2592        assert_eq!(builder.posting_lists[0].len(), 2);
2593        assert_eq!(builder.docs.len(), 2);
2594        assert_eq!(builder.docs.row_id(0), 1);
2595        assert_eq!(builder.docs.row_id(1), 3);
2596
2597        builder.write(store.as_ref()).await.unwrap();
2598
2599        // remap to delete all docs
2600        let mapping = HashMap::from([(1, None), (3, None)]);
2601        builder.remap(&mapping).await.unwrap();
2602
2603        assert_eq!(builder.tokens.len(), 0);
2604        assert_eq!(builder.posting_lists.len(), 0);
2605        assert_eq!(builder.docs.len(), 0);
2606
2607        builder.write(store.as_ref()).await.unwrap();
2608    }
2609
2610    #[tokio::test]
2611    async fn test_posting_cache_conflict_across_partitions() {
2612        let tmpdir = TempObjDir::default();
2613        let store = Arc::new(LanceIndexStore::new(
2614            ObjectStore::local().into(),
2615            tmpdir.clone(),
2616            Arc::new(LanceCache::no_cache()),
2617        ));
2618
2619        // Create first partition with one token and posting list length 1
2620        let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
2621        builder1.tokens.add("test".to_owned());
2622        builder1.posting_lists.push(PostingListBuilder::new(false));
2623        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
2624        builder1.docs.append(100, 1); // row_id=100, num_tokens=1
2625        builder1.write(store.as_ref()).await.unwrap();
2626
2627        // Create second partition with one token and posting list length 4
2628        let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
2629        builder2.tokens.add("test".to_owned()); // Use same token to test cache prefix fix
2630        builder2.posting_lists.push(PostingListBuilder::new(false));
2631        builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
2632        builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
2633        builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
2634        builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
2635        builder2.docs.append(200, 2); // row_id=200, num_tokens=2
2636        builder2.docs.append(201, 1); // row_id=201, num_tokens=1
2637        builder2.docs.append(202, 3); // row_id=202, num_tokens=3
2638        builder2.docs.append(203, 1); // row_id=203, num_tokens=1
2639        builder2.write(store.as_ref()).await.unwrap();
2640
2641        // Create metadata file with both partitions
2642        let metadata = std::collections::HashMap::from_iter(vec![
2643            (
2644                "partitions".to_owned(),
2645                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
2646            ),
2647            (
2648                "params".to_owned(),
2649                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
2650            ),
2651            (
2652                TOKEN_SET_FORMAT_KEY.to_owned(),
2653                TokenSetFormat::default().to_string(),
2654            ),
2655        ]);
2656        let mut writer = store
2657            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
2658            .await
2659            .unwrap();
2660        writer.finish_with_metadata(metadata).await.unwrap();
2661
2662        // Load the inverted index
2663        let cache = Arc::new(LanceCache::with_capacity(4096));
2664        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
2665            .await
2666            .unwrap();
2667
2668        // Verify the index structure
2669        assert_eq!(index.partitions.len(), 2);
2670        assert_eq!(index.partitions[0].tokens.len(), 1);
2671        assert_eq!(index.partitions[1].tokens.len(), 1);
2672
2673        // Verify the partitions were loaded correctly
2674
2675        // Verify posting list lengths (note: partition order may differ from creation order)
2676        // Verify based on actual loading order
2677        if index.partitions[0].id() == 0 {
2678            // If partition[0] is ID=0, then it should have 1 document
2679            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
2680            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
2681            assert_eq!(index.partitions[0].docs.len(), 1);
2682            assert_eq!(index.partitions[1].docs.len(), 4);
2683        } else {
2684            // If partition[0] is ID=1, then it should have 4 documents
2685            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
2686            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
2687            assert_eq!(index.partitions[0].docs.len(), 4);
2688            assert_eq!(index.partitions[1].docs.len(), 1);
2689        }
2690
2691        // Prewarm the inverted index (this loads posting lists into cache)
2692        index.prewarm().await.unwrap();
2693
2694        let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
2695        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
2696        let prefilter = Arc::new(NoFilter);
2697        let metrics = Arc::new(NoOpMetricsCollector);
2698
2699        let (row_ids, scores) = index
2700            .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
2701            .await
2702            .unwrap();
2703
2704        // Verify that we got search results
2705        // Expected to find 5 documents: 1 from first partition, 4 from second partition
2706        assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
2707        assert!(!row_ids.is_empty(), "Should find at least some documents");
2708        assert_eq!(row_ids.len(), scores.len());
2709
2710        // All scores should be positive since all documents contain the search token
2711        for &score in &scores {
2712            assert!(score > 0.0, "All scores should be positive");
2713        }
2714
2715        // Check that we got results from both partitions
2716        assert!(
2717            row_ids.contains(&100),
2718            "Should contain row_id from partition 0"
2719        );
2720        assert!(
2721            row_ids.iter().any(|&id| id >= 200),
2722            "Should contain row_id from partition 1"
2723        );
2724    }
2725}