Skip to main content

lance_index/scalar/inverted/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use std::{
8    cmp::{Reverse, min},
9    collections::BinaryHeap,
10};
11use std::{
12    collections::{HashMap, HashSet},
13    ops::Range,
14    time::Instant,
15};
16
17use crate::metrics::NoOpMetricsCollector;
18use crate::prefilter::NoFilter;
19use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
20use arrow::array::{FixedSizeListBuilder, Float32Builder};
21use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
22use arrow::{
23    array::{
24        AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
25    },
26    buffer::{Buffer, OffsetBuffer},
27};
28use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
29use arrow_array::{
30    Array, ArrayRef, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait, RecordBatch,
31    UInt32Array, UInt64Array,
32};
33use arrow_schema::{DataType, Field, Schema, SchemaRef};
34use async_trait::async_trait;
35use datafusion::execution::SendableRecordBatchStream;
36use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
37use deepsize::DeepSizeOf;
38use fst::{Automaton, IntoStreamer, Streamer};
39use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream};
40use itertools::Itertools;
41use lance_arrow::{RecordBatchExt, iter_str_array};
42use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
43use lance_core::error::{DataFusionResult, LanceOptionExt};
44use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
45use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu};
46use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
47use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result};
48use roaring::RoaringBitmap;
49use std::sync::LazyLock;
50use tokio::task::spawn_blocking;
51use tracing::{info, instrument};
52
53use super::encoding::PositionBlockBuilder;
54use super::iter::PostingListIterator;
55use super::{InvertedIndexBuilder, InvertedIndexParams, wand::*};
56use super::{
57    builder::{
58        BLOCK_SIZE, ScoredDoc, doc_file_path, inverted_list_schema_for_version, posting_file_path,
59        token_file_path,
60    },
61    iter::PlainPostingListIterator,
62    query::*,
63    scorer::{B, IndexBM25Scorer, K1, Scorer, idf},
64};
65use super::{
66    builder::{InnerBuilder, PositionRecorder},
67    iter::CompressedPostingListIterator,
68};
69use crate::frag_reuse::FragReuseIndex;
70use crate::pbold;
71use crate::scalar::inverted::scorer::MemBM25Scorer;
72use crate::scalar::inverted::tokenizer::document_tokenizer::LanceTokenizer;
73use crate::scalar::{
74    AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
75    ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
76};
77use crate::{FtsPrewarmOptions, Index};
78use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
79use std::str::FromStr;
80
// Version 0: Arrow TokenSetFormat (legacy)
// Version 1: Fst TokenSetFormat with per-doc compressed positions
// Version 2: Fst TokenSetFormat with shared posting-list position streams.
pub const INVERTED_INDEX_VERSION_V1: u32 = 1;
pub const INVERTED_INDEX_VERSION_V2: u32 = 2;
// Names of the files that are opened through the index store when loading
// an index (see `InvertedIndex::load` and `InvertedIndex::load_legacy_index`).
pub const TOKENS_FILE: &str = "tokens.lance";
pub const INVERT_LIST_FILE: &str = "invert.lance";
pub const DOCS_FILE: &str = "docs.lance";
pub const METADATA_FILE: &str = "metadata.lance";

// Column names. NOTE(review): most of these are consumed outside this part of
// the file; presumably they name columns of the index files -- confirm against
// the builder/reader code.
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSITION_BLOCK_OFFSET_COL: &str = "_position_block_offset";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
// Name of the score field in `SCORE_FIELD` / `FTS_SCHEMA`.
pub const SCORE_COL: &str = "_score";
// Metadata keys looked up when loading an index and when parsing codecs.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";
pub const POSTING_TAIL_CODEC_KEY: &str = "posting_tail_codec";
pub const POSITIONS_LAYOUT_KEY: &str = "positions_layout";
pub const POSITIONS_CODEC_KEY: &str = "positions_codec";
// Metadata values: the serialized names of `PostingTailCodec` variants, the
// shared position-stream layout and `PositionStreamCodec` variants.
pub const POSTING_TAIL_CODEC_FIXED32_V1: &str = "fixed32_v1";
pub const POSTING_TAIL_CODEC_VARINT_DELTA_V1: &str = "varint_delta_v1";
pub const POSITIONS_LAYOUT_SHARED_STREAM_V2: &str = "shared_stream_v2";
pub const POSITIONS_CODEC_VARINT_DOC_DELTA_V2: &str = "varint_doc_delta_v2";
pub const POSITIONS_CODEC_PACKED_DELTA_V1: &str = "packed_delta_v1";
// Column of the metadata file's first row that holds the serialized bitmap of
// deleted fragments (read in `InvertedIndex::load`).
pub const DELETED_FRAGMENTS_COL: &str = "deleted_fragments";

// Just a heuristic when we need to pre-allocate memory for tokens
pub const ESTIMATED_MAX_TOKENS_PER_ROW: usize = 4 * 1024;

// Nullable Float32 field named `SCORE_COL`, built lazily on first use.
pub static SCORE_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
// Two-field schema: row id followed by score.
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
// Single-field schema holding only the row id; used for `do_search` results.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
126
127fn resolve_fts_format_version(
128    value: Option<&str>,
129) -> std::result::Result<InvertedListFormatVersion, Error> {
130    value.unwrap_or("1").parse()
131}
132
133pub fn current_fts_format_version() -> InvertedListFormatVersion {
134    resolve_fts_format_version(std::env::var("LANCE_FTS_FORMAT_VERSION").ok().as_deref())
135        .expect("failed to parse LANCE_FTS_FORMAT_VERSION")
136}
137
/// Returns the newest FTS format version known to this code, which is
/// [`InvertedListFormatVersion::V2`].
pub fn max_supported_fts_format_version() -> InvertedListFormatVersion {
    InvertedListFormatVersion::V2
}
141
/// On-disk format version of the inverted (posting) lists.
///
/// Each version fixes the posting tail codec and, for `V2`, the position
/// stream codec; see the accessor methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum InvertedListFormatVersion {
    /// Uses the `Fixed32` posting tail codec and no position stream codec.
    #[default]
    V1,
    /// Uses the `VarintDelta` posting tail codec and the `PackedDelta`
    /// position codec with a shared position stream.
    V2,
}
148
149impl InvertedListFormatVersion {
150    pub fn from_posting_tail_codec(codec: PostingTailCodec) -> Self {
151        match codec {
152            PostingTailCodec::Fixed32 => Self::V1,
153            PostingTailCodec::VarintDelta => Self::V2,
154        }
155    }
156
157    pub fn index_version(self) -> u32 {
158        match self {
159            Self::V1 => INVERTED_INDEX_VERSION_V1,
160            Self::V2 => INVERTED_INDEX_VERSION_V2,
161        }
162    }
163
164    pub fn posting_tail_codec(self) -> PostingTailCodec {
165        match self {
166            Self::V1 => PostingTailCodec::Fixed32,
167            Self::V2 => PostingTailCodec::VarintDelta,
168        }
169    }
170
171    pub fn position_codec(self) -> Option<PositionStreamCodec> {
172        match self {
173            Self::V1 => None,
174            Self::V2 => Some(PositionStreamCodec::PackedDelta),
175        }
176    }
177
178    pub fn uses_shared_position_stream(self) -> bool {
179        matches!(self, Self::V2)
180    }
181}
182
183impl FromStr for InvertedListFormatVersion {
184    type Err = Error;
185
186    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
187        match s.trim() {
188            "1" | "v1" | "V1" => Ok(Self::V1),
189            "2" | "v2" | "V2" => Ok(Self::V2),
190            other => Err(Error::index(format!(
191                "unsupported FTS format version {}, expected 1 or 2",
192                other
193            ))),
194        }
195    }
196}
197
198#[derive(Debug)]
199struct PartitionCandidates {
200    tokens_by_position: Vec<String>,
201    candidates: Vec<DocCandidate>,
202}
203
204impl PartitionCandidates {
205    fn empty() -> Self {
206        Self {
207            tokens_by_position: Vec::new(),
208            candidates: Vec::new(),
209        }
210    }
211}
212
/// Storage format of a token set; `Fst` is the default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
    Arrow,
    #[default]
    Fst,
}

impl Display for TokenSetFormat {
    /// Writes the lowercase format name (`arrow` or `fst`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Arrow => "arrow",
            Self::Fst => "fst",
        };
        f.write_str(name)
    }
}
228
229impl FromStr for TokenSetFormat {
230    type Err = Error;
231
232    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
233        match s.trim() {
234            "" => Ok(Self::Arrow),
235            "arrow" => Ok(Self::Arrow),
236            "fst" => Ok(Self::Fst),
237            other => Err(Error::index(format!(
238                "unsupported token set format {}",
239                other
240            ))),
241        }
242    }
243}
244
impl DeepSizeOf for TokenSetFormat {
    // The enum has only unit variants, so it reports no child size.
    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
        0
    }
}
250
/// Codec of a position stream; `PackedDelta` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PositionStreamCodec {
    /// Serialized in metadata as `POSITIONS_CODEC_VARINT_DOC_DELTA_V2`.
    VarintDocDelta,
    /// Serialized in metadata as `POSITIONS_CODEC_PACKED_DELTA_V1`.
    #[default]
    PackedDelta,
}
257
/// Codec of the posting list tail; `VarintDelta` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PostingTailCodec {
    /// Serialized in metadata as `POSTING_TAIL_CODEC_FIXED32_V1`.
    Fixed32,
    /// Serialized in metadata as `POSTING_TAIL_CODEC_VARINT_DELTA_V1`.
    #[default]
    VarintDelta,
}
264
265impl PostingTailCodec {
266    pub fn as_str(self) -> &'static str {
267        match self {
268            Self::Fixed32 => POSTING_TAIL_CODEC_FIXED32_V1,
269            Self::VarintDelta => POSTING_TAIL_CODEC_VARINT_DELTA_V1,
270        }
271    }
272
273    fn from_metadata_value(value: &str) -> Result<Self> {
274        match value.trim() {
275            POSTING_TAIL_CODEC_FIXED32_V1 => Ok(Self::Fixed32),
276            POSTING_TAIL_CODEC_VARINT_DELTA_V1 => Ok(Self::VarintDelta),
277            other => Err(Error::index(format!(
278                "unsupported posting tail codec {}",
279                other
280            ))),
281        }
282    }
283}
284
285pub(super) fn parse_posting_tail_codec(
286    metadata: &HashMap<String, String>,
287) -> Result<PostingTailCodec> {
288    Ok(metadata
289        .get(POSTING_TAIL_CODEC_KEY)
290        .map(|codec| PostingTailCodec::from_metadata_value(codec))
291        .transpose()?
292        .unwrap_or(PostingTailCodec::Fixed32))
293}
294
295impl PositionStreamCodec {
296    pub fn as_str(self) -> &'static str {
297        match self {
298            Self::VarintDocDelta => POSITIONS_CODEC_VARINT_DOC_DELTA_V2,
299            Self::PackedDelta => POSITIONS_CODEC_PACKED_DELTA_V1,
300        }
301    }
302
303    fn from_metadata_value(value: &str) -> Result<Self> {
304        match value.trim() {
305            POSITIONS_CODEC_VARINT_DOC_DELTA_V2 => Ok(Self::VarintDocDelta),
306            POSITIONS_CODEC_PACKED_DELTA_V1 => Ok(Self::PackedDelta),
307            other => Err(Error::index(format!(
308                "unsupported positions codec {}",
309                other
310            ))),
311        }
312    }
313}
314
315fn parse_shared_position_codec(metadata: &HashMap<String, String>) -> Result<PositionStreamCodec> {
316    if let Some(codec) = metadata.get(POSITIONS_CODEC_KEY) {
317        return PositionStreamCodec::from_metadata_value(codec);
318    }
319
320    match metadata
321        .get(POSITIONS_LAYOUT_KEY)
322        .map(|layout| layout.as_str())
323    {
324        Some(POSITIONS_LAYOUT_SHARED_STREAM_V2) => Ok(PositionStreamCodec::VarintDocDelta),
325        _ => Ok(PositionStreamCodec::VarintDocDelta),
326    }
327}
328
329pub(super) fn parse_format_version_from_metadata(
330    metadata: &HashMap<String, String>,
331) -> Result<InvertedListFormatVersion> {
332    if metadata.contains_key(POSITIONS_CODEC_KEY) || metadata.contains_key(POSITIONS_LAYOUT_KEY) {
333        return Ok(InvertedListFormatVersion::V2);
334    }
335    if parse_posting_tail_codec(metadata)? == PostingTailCodec::VarintDelta {
336        Ok(InvertedListFormatVersion::V2)
337    } else {
338        Ok(InvertedListFormatVersion::V1)
339    }
340}
341
/// A full-text search index made of one or more inverted partitions.
#[derive(Clone)]
pub struct InvertedIndex {
    // Index parameters; the tokenizer is built from them when loading.
    params: InvertedIndexParams,
    // Store the index files are read from.
    store: Arc<dyn IndexStore>,
    // Tokenizer handed out (as a clone) by `tokenizer()` and used for queries.
    tokenizer: Box<dyn LanceTokenizer>,
    // Token set format shared by all partitions of this index.
    token_set_format: TokenSetFormat,
    // The partitions of this index; a legacy index has exactly one.
    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
    // Fragments which are contained in the index, but no longer in the dataset.
    // These should be pruned at search time since we don't prune them at update time.
    deleted_fragments: RoaringBitmap,
}
353
354impl Debug for InvertedIndex {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        f.debug_struct("InvertedIndex")
357            .field("params", &self.params)
358            .field("token_set_format", &self.token_set_format)
359            .field("partitions", &self.partitions)
360            .field("deleted_fragments", &self.deleted_fragments)
361            .finish()
362    }
363}
364
impl DeepSizeOf for InvertedIndex {
    // Only the partitions are counted; the other fields are not included in
    // the reported size.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.partitions.deep_size_of_children(context)
    }
}
370
371impl InvertedIndex {
372    fn format_version(&self) -> InvertedListFormatVersion {
373        self.partitions
374            .first()
375            .map(|partition| {
376                InvertedListFormatVersion::from_posting_tail_codec(
377                    partition.inverted_list.posting_tail_codec(),
378                )
379            })
380            .unwrap_or_else(current_fts_format_version)
381    }
382
383    fn index_version(&self) -> u32 {
384        match self.token_set_format {
385            TokenSetFormat::Arrow => 0,
386            TokenSetFormat::Fst => self.format_version().index_version(),
387        }
388    }
389
390    fn posting_tail_codec(&self) -> PostingTailCodec {
391        self.partitions
392            .first()
393            .map(|partition| partition.inverted_list.posting_tail_codec())
394            .unwrap_or_default()
395    }
396
397    fn to_builder(&self) -> InvertedIndexBuilder {
398        self.to_builder_with_offset(None)
399    }
400
401    fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
402        if self.is_legacy() {
403            // for legacy format, we re-create the index in the new format
404            InvertedIndexBuilder::from_existing_index(
405                self.params.clone(),
406                None,
407                Vec::new(),
408                self.token_set_format,
409                fragment_mask,
410                self.deleted_fragments.clone(),
411            )
412            .with_posting_tail_codec(self.posting_tail_codec())
413        } else {
414            let partitions = match fragment_mask {
415                Some(fragment_mask) => self
416                    .partitions
417                    .iter()
418                    // Filter partitions that belong to the specified fragment
419                    // The mask contains fragment_id in high 32 bits, we check if partition's
420                    // fragment_id matches by comparing the masked result with the original mask
421                    .filter(|part| part.belongs_to_fragment(fragment_mask))
422                    .map(|part| part.id())
423                    .collect(),
424                None => self.partitions.iter().map(|part| part.id()).collect(),
425            };
426
427            InvertedIndexBuilder::from_existing_index(
428                self.params.clone(),
429                Some(self.store.clone()),
430                partitions,
431                self.token_set_format,
432                fragment_mask,
433                self.deleted_fragments.clone(),
434            )
435            .with_format_version(self.format_version())
436        }
437    }
438
439    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
440        self.tokenizer.clone()
441    }
442
443    pub fn params(&self) -> &InvertedIndexParams {
444        &self.params
445    }
446
447    /// Returns the number of partitions in this inverted index.
448    pub fn partition_count(&self) -> usize {
449        self.partitions.len()
450    }
451    /// Returns the set of fragments which are contained in the index, but no longer in the dataset.
452    ///
453    /// Most other indices remove data from deleted fragments when the index updates (copy-on-write).
454    /// However, this would require an expensive copy of the FTS index.  Instead, we track the deleted
455    /// fragments and prune them at search time (merge-on-read).
456    pub fn deleted_fragments(&self) -> &RoaringBitmap {
457        &self.deleted_fragments
458    }
459
460    pub fn bm25_base_scorer(&self, query_tokens: &Tokens) -> MemBM25Scorer {
461        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
462        let token_docs = query_tokens
463            .into_iter()
464            .map(|token| (token.to_string(), scorer.num_docs_containing_token(token)))
465            .collect::<HashMap<_, _>>();
466        MemBM25Scorer::new(scorer.total_tokens(), scorer.num_docs(), token_docs)
467    }
468
469    pub fn bm25_stats_for_terms(&self, terms: &[String]) -> (u64, usize, Vec<usize>) {
470        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
471        let token_docs = terms
472            .iter()
473            .map(|term| scorer.num_docs_containing_token(term))
474            .collect();
475        (scorer.total_tokens(), scorer.num_docs(), token_docs)
476    }
477
478    /// Expand fuzzy query tokens against all partitions in this segment.
479    pub fn expand_fuzzy_tokens(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
480        let mut expanded_tokens = Vec::new();
481        let mut expanded_positions = Vec::new();
482        let mut seen = HashSet::new();
483        for partition in &self.partitions {
484            let expanded = partition.expand_fuzzy(tokens, params)?;
485            for idx in 0..expanded.len() {
486                let token = expanded.get_token(idx);
487                if seen.insert(token.to_string()) {
488                    expanded_tokens.push(token.to_string());
489                    expanded_positions.push(expanded.position(idx));
490                }
491            }
492        }
493        Ok(Tokens::with_positions(
494            expanded_tokens,
495            expanded_positions,
496            tokens.token_type().clone(),
497        ))
498    }
499
500    /// Search documents that match the query and return row ids sorted by BM25 score.
501    ///
502    /// When `base_scorer` is provided, search uses those corpus-level BM25 statistics
503    /// instead of deriving them from this segment alone.
504    #[instrument(level = "debug", skip_all)]
505    pub async fn bm25_search(
506        &self,
507        tokens: Arc<Tokens>,
508        params: Arc<FtsSearchParams>,
509        operator: Operator,
510        prefilter: Arc<dyn PreFilter>,
511        metrics: Arc<dyn MetricsCollector>,
512        base_scorer: Option<&MemBM25Scorer>,
513    ) -> Result<(Vec<u64>, Vec<f32>)> {
514        let local_scorer;
515        let scorer: &dyn Scorer = if let Some(base_scorer) = base_scorer {
516            base_scorer
517        } else {
518            local_scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
519            &local_scorer
520        };
521
522        let limit = params.limit.unwrap_or(usize::MAX);
523        if limit == 0 {
524            return Ok((Vec::new(), Vec::new()));
525        }
526        let mask = prefilter.mask();
527
528        let mut candidates = BinaryHeap::new();
529        let parts = self
530            .partitions
531            .iter()
532            .map(|part| {
533                let part = part.clone();
534                let tokens = tokens.clone();
535                let params = params.clone();
536                let mask = mask.clone();
537                let metrics = metrics.clone();
538                async move {
539                    let postings = part
540                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
541                        .await?;
542                    if postings.is_empty() {
543                        return Result::Ok(PartitionCandidates::empty());
544                    }
545                    let max_position = postings
546                        .iter()
547                        .map(|posting| posting.term_index() as usize)
548                        .max()
549                        .unwrap_or_default();
550                    let mut tokens_by_position = vec![String::new(); max_position + 1];
551                    for posting in &postings {
552                        let idx = posting.term_index() as usize;
553                        tokens_by_position[idx] = posting.token().to_owned();
554                    }
555                    let params = params.clone();
556                    let mask = mask.clone();
557                    let metrics = metrics.clone();
558                    spawn_cpu(move || {
559                        let candidates = part.bm25_search(
560                            params.as_ref(),
561                            operator,
562                            mask,
563                            postings,
564                            metrics.as_ref(),
565                        )?;
566                        Ok(PartitionCandidates {
567                            tokens_by_position,
568                            candidates,
569                        })
570                    })
571                    .await
572                }
573            })
574            .collect::<Vec<_>>();
575        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
576        let mut idf_cache: HashMap<String, f32> = HashMap::new();
577        while let Some(res) = parts.try_next().await? {
578            if res.candidates.is_empty() {
579                continue;
580            }
581            let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
582            for token in &res.tokens_by_position {
583                let idf_weight = match idf_cache.get(token) {
584                    Some(weight) => *weight,
585                    None => {
586                        let weight = scorer.query_weight(token);
587                        idf_cache.insert(token.clone(), weight);
588                        weight
589                    }
590                };
591                idf_by_position.push(idf_weight);
592            }
593            for DocCandidate {
594                row_id,
595                freqs,
596                doc_length,
597            } in res.candidates
598            {
599                let mut score = 0.0;
600                for (term_index, freq) in freqs.into_iter() {
601                    debug_assert!((term_index as usize) < idf_by_position.len());
602                    score +=
603                        idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length);
604                }
605                if candidates.len() < limit {
606                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
607                } else if candidates.peek().unwrap().0.score.0 < score {
608                    candidates.pop();
609                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
610                }
611            }
612        }
613
614        Ok(candidates
615            .into_sorted_vec()
616            .into_iter()
617            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
618            .unzip())
619    }
620
621    async fn load_legacy_index(
622        store: Arc<dyn IndexStore>,
623        frag_reuse_index: Option<Arc<FragReuseIndex>>,
624        index_cache: &LanceCache,
625    ) -> Result<Arc<Self>> {
626        log::warn!("loading legacy FTS index");
627        let tokens_fut = tokio::spawn({
628            let store = store.clone();
629            async move {
630                let token_reader = store.open_index_file(TOKENS_FILE).await?;
631                let tokenizer = token_reader
632                    .schema()
633                    .metadata
634                    .get("tokenizer")
635                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
636                    .transpose()?
637                    .unwrap_or_default();
638                let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
639                Result::Ok((tokenizer, tokens))
640            }
641        });
642        let invert_list_fut = tokio::spawn({
643            let store = store.clone();
644            let index_cache_clone = index_cache.clone();
645            async move {
646                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
647                let invert_list =
648                    PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
649                Result::Ok(Arc::new(invert_list))
650            }
651        });
652        let docs_fut = tokio::spawn({
653            let store = store.clone();
654            async move {
655                let docs_reader = store.open_index_file(DOCS_FILE).await?;
656                let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
657                Result::Ok(docs)
658            }
659        });
660
661        let (tokenizer_config, tokens) = tokens_fut.await??;
662        let inverted_list = invert_list_fut.await??;
663        let docs = docs_fut.await??;
664
665        let tokenizer = tokenizer_config.build()?;
666
667        Ok(Arc::new(Self {
668            params: tokenizer_config,
669            store: store.clone(),
670            tokenizer,
671            token_set_format: TokenSetFormat::Arrow,
672            partitions: vec![Arc::new(InvertedPartition {
673                id: 0,
674                store,
675                tokens,
676                inverted_list,
677                docs,
678                token_set_format: TokenSetFormat::Arrow,
679            })],
680            deleted_fragments: RoaringBitmap::new(),
681        }))
682    }
683
684    pub fn is_legacy(&self) -> bool {
685        self.partitions.len() == 1 && self.partitions[0].is_legacy()
686    }
687
688    pub async fn load(
689        store: Arc<dyn IndexStore>,
690        frag_reuse_index: Option<Arc<FragReuseIndex>>,
691        index_cache: &LanceCache,
692    ) -> Result<Arc<Self>>
693    where
694        Self: Sized,
695    {
696        // for new index format, there is a metadata file and multiple partitions,
697        // each partition is a separate index containing tokens, inverted list and docs.
698        // for old index format, there is no metadata file, and it's just like a single partition
699
700        match store.open_index_file(METADATA_FILE).await {
701            Ok(reader) => {
702                let params = reader
703                    .schema()
704                    .metadata
705                    .get("params")
706                    .ok_or(Error::index("params not found in metadata".to_owned()))?;
707                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
708                let partitions = reader
709                    .schema()
710                    .metadata
711                    .get("partitions")
712                    .ok_or(Error::index("partitions not found in metadata".to_owned()))?;
713                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
714                let token_set_format = reader
715                    .schema()
716                    .metadata
717                    .get(TOKEN_SET_FORMAT_KEY)
718                    .map(|name| TokenSetFormat::from_str(name))
719                    .transpose()?
720                    .unwrap_or(TokenSetFormat::Arrow);
721
722                // Load deleted_fragments if present (optional for backward compatibility)
723                let deleted_fragments = if reader.num_rows() > 0 {
724                    let metadata_batch = reader.read_range(0..1, None).await?;
725                    if let Some(col) = metadata_batch.column_by_name(DELETED_FRAGMENTS_COL) {
726                        let arr = col.as_binary_opt::<i32>().expect_ok()?;
727                        RoaringBitmap::deserialize_from(arr.value(0))?
728                    } else {
729                        RoaringBitmap::new()
730                    }
731                } else {
732                    RoaringBitmap::new()
733                };
734
735                let format = token_set_format;
736                let partitions = partitions.into_iter().map(|id| {
737                    let store = store.clone();
738                    let frag_reuse_index_clone = frag_reuse_index.clone();
739                    let index_cache_for_part =
740                        index_cache.with_key_prefix(format!("part-{}", id).as_str());
741                    let token_set_format = format;
742                    async move {
743                        Result::Ok(Arc::new(
744                            InvertedPartition::load(
745                                store,
746                                id,
747                                frag_reuse_index_clone,
748                                &index_cache_for_part,
749                                token_set_format,
750                            )
751                            .await?,
752                        ))
753                    }
754                });
755                let partitions = stream::iter(partitions)
756                    .buffer_unordered(store.io_parallelism())
757                    .try_collect::<Vec<_>>()
758                    .await?;
759
760                let tokenizer = params.build()?;
761                Ok(Arc::new(Self {
762                    params,
763                    store,
764                    tokenizer,
765                    token_set_format,
766                    partitions,
767                    deleted_fragments,
768                }))
769            }
770            Err(_) => {
771                // old index format
772                Self::load_legacy_index(store, frag_reuse_index, index_cache).await
773            }
774        }
775    }
776}
777
778#[async_trait]
779impl Index for InvertedIndex {
780    fn as_any(&self) -> &dyn std::any::Any {
781        self
782    }
783
784    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
785        self
786    }
787
788    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
789        Err(Error::invalid_input(
790            "inverted index cannot be cast to vector index",
791        ))
792    }
793
794    fn statistics(&self) -> Result<serde_json::Value> {
795        let num_tokens = self
796            .partitions
797            .iter()
798            .map(|part| part.tokens.len())
799            .sum::<usize>();
800        let num_docs = self
801            .partitions
802            .iter()
803            .map(|part| part.docs.len())
804            .sum::<usize>();
805        Ok(serde_json::json!({
806            "params": self.params,
807            "num_tokens": num_tokens,
808            "num_docs": num_docs,
809        }))
810    }
811
812    async fn prewarm(&self) -> Result<()> {
813        self.prewarm_with_options(&FtsPrewarmOptions::default())
814            .await
815    }
816
817    fn index_type(&self) -> crate::IndexType {
818        crate::IndexType::Inverted
819    }
820
821    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
822        unimplemented!()
823    }
824}
825
826impl InvertedIndex {
827    pub async fn prewarm_with_options(&self, options: &FtsPrewarmOptions) -> Result<()> {
828        let with_position = options.with_position;
829        let io_parallelism = self.store.io_parallelism();
830        let prewarm_futures = self
831            .partitions
832            .iter()
833            .map(Arc::clone)
834            .map(|part| async move {
835                part.inverted_list
836                    .prewarm_posting_lists(with_position)
837                    .await?;
838                Result::Ok(())
839            });
840        stream::iter(prewarm_futures)
841            .buffer_unordered(io_parallelism)
842            .try_collect::<Vec<_>>()
843            .await?;
844        Ok(())
845    }
846    /// Search docs match the input text.
847    async fn do_search(&self, text: &str) -> Result<RecordBatch> {
848        let params = FtsSearchParams::new();
849        let mut tokenizer = self.tokenizer.clone();
850        let tokens = collect_query_tokens(text, &mut tokenizer);
851
852        let (doc_ids, _) = self
853            .bm25_search(
854                Arc::new(tokens),
855                params.into(),
856                Operator::And,
857                Arc::new(NoFilter),
858                Arc::new(NoOpMetricsCollector),
859                None,
860            )
861            .boxed()
862            .await?;
863
864        Ok(RecordBatch::try_new(
865            ROW_ID_SCHEMA.clone(),
866            vec![Arc::new(UInt64Array::from(doc_ids))],
867        )?)
868    }
869}
870
871#[async_trait]
872impl ScalarIndex for InvertedIndex {
873    // return the row ids of the documents that contain the query
874    #[instrument(level = "debug", skip_all)]
875    async fn search(
876        &self,
877        query: &dyn AnyQuery,
878        _metrics: &dyn MetricsCollector,
879    ) -> Result<SearchResult> {
880        let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();
881
882        match query {
883            TokenQuery::TokensContains(text) => {
884                let records = self.do_search(text).await?;
885                let row_ids = records
886                    .column(0)
887                    .as_any()
888                    .downcast_ref::<UInt64Array>()
889                    .unwrap();
890                let row_ids = row_ids.iter().flatten().collect_vec();
891                Ok(SearchResult::at_most(RowAddrTreeMap::from_iter(row_ids)))
892            }
893        }
894    }
895
    /// The inverted index supports row id remapping (see `remap`), so this is always `true`.
    fn can_remap(&self) -> bool {
        true
    }
899
900    async fn remap(
901        &self,
902        mapping: &HashMap<u64, Option<u64>>,
903        dest_store: &dyn IndexStore,
904    ) -> Result<CreatedIndex> {
905        self.to_builder()
906            .remap(mapping, self.store.clone(), dest_store)
907            .await?;
908
909        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
910
911        Ok(CreatedIndex {
912            index_details: prost_types::Any::from_msg(&details).unwrap(),
913            index_version: self.index_version(),
914            files: Some(dest_store.list_files_with_sizes().await?),
915        })
916    }
917
918    async fn update(
919        &self,
920        new_data: SendableRecordBatchStream,
921        dest_store: &dyn IndexStore,
922        old_data_filter: Option<crate::scalar::OldIndexDataFilter>,
923    ) -> Result<CreatedIndex> {
924        self.to_builder()
925            .update(new_data, dest_store, old_data_filter)
926            .await?;
927
928        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
929
930        Ok(CreatedIndex {
931            index_details: prost_types::Any::from_msg(&details).unwrap(),
932            index_version: self.index_version(),
933            files: Some(dest_store.list_files_with_sizes().await?),
934        })
935    }
936
937    fn update_criteria(&self) -> UpdateCriteria {
938        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
939        if self.is_legacy() {
940            UpdateCriteria::requires_old_data(criteria)
941        } else {
942            UpdateCriteria::only_new_data(criteria)
943        }
944    }
945
946    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
947        let mut params = self.params.clone();
948        if params.base_tokenizer.is_empty() {
949            params.base_tokenizer = "simple".to_string();
950        }
951
952        let params_json = serde_json::to_string(&params)?;
953
954        Ok(ScalarIndexParams {
955            index_type: BuiltinIndexType::Inverted.as_str().to_string(),
956            params: Some(params_json),
957        })
958    }
959}
960
/// One partition of an inverted index: its token set, posting list reader and
/// document set, together with the store it was loaded from.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct InvertedPartition {
    // 0 for legacy format
    id: u64,
    // store handle; `load` opens the token, posting and doc files from it
    store: Arc<dyn IndexStore>,
    // token -> token id mapping of this partition
    pub(crate) tokens: TokenSet,
    // reader for the posting lists of this partition
    pub(crate) inverted_list: Arc<PostingListReader>,
    // document set of this partition; its `len()` is used as the document count
    pub(crate) docs: DocSet,
    // token set format that was passed to `TokenSet::load` for this partition
    token_set_format: TokenSetFormat,
}
971
972impl InvertedPartition {
973    /// Check if this partition belongs to the specified fragment.
974    ///
975    /// This method encapsulates the bit manipulation logic for fragment filtering
976    /// in distributed indexing scenarios.
977    ///
978    /// # Arguments
979    /// * `fragment_mask` - A mask with fragment_id in high 32 bits
980    ///
981    /// # Returns
982    /// * `true` if the partition belongs to the fragment, `false` otherwise
983    pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
984        (self.id() & fragment_mask) == fragment_mask
985    }
986
    /// Returns the partition id (0 for the legacy format).
    pub fn id(&self) -> u64 {
        self.id
    }
990
    /// Returns the index store held by this partition.
    pub fn store(&self) -> &dyn IndexStore {
        self.store.as_ref()
    }
994
    /// Returns `true` when the posting list reader has no per-list lengths,
    /// which are only present for the new format.
    pub fn is_legacy(&self) -> bool {
        self.inverted_list.lengths.is_none()
    }
998
999    pub async fn load(
1000        store: Arc<dyn IndexStore>,
1001        id: u64,
1002        frag_reuse_index: Option<Arc<FragReuseIndex>>,
1003        index_cache: &LanceCache,
1004        token_set_format: TokenSetFormat,
1005    ) -> Result<Self> {
1006        let token_file = store.open_index_file(&token_file_path(id)).await?;
1007        let tokens = TokenSet::load(token_file, token_set_format).await?;
1008        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
1009        let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
1010        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
1011        let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;
1012
1013        Ok(Self {
1014            id,
1015            store,
1016            tokens,
1017            inverted_list: Arc::new(inverted_list),
1018            docs,
1019            token_set_format,
1020        })
1021    }
1022
    /// Looks up the id of `token` in this partition's token set, if present.
    fn map(&self, token: &str) -> Option<u32> {
        self.tokens.get(token)
    }
1026
1027    pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
1028        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
1029        for token in tokens {
1030            let fuzziness = match params.fuzziness {
1031                Some(fuzziness) => fuzziness,
1032                None => MatchQuery::auto_fuzziness(token),
1033            };
1034            let lev = fst::automaton::Levenshtein::new(token, fuzziness)
1035                .map_err(|e| Error::index(format!("failed to construct the fuzzy query: {}", e)))?;
1036
1037            let base_len = tokens.token_type().prefix_len(token) as u32;
1038            if let TokenMap::Fst(ref map) = self.tokens.tokens {
1039                match base_len + params.prefix_length {
1040                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
1041                    prefix_length => {
1042                        let prefix = &token[..min(prefix_length as usize, token.len())];
1043                        let prefix = fst::automaton::Str::new(prefix).starts_with();
1044                        take_fst_keys(
1045                            map.search(lev.intersection(prefix)),
1046                            &mut new_tokens,
1047                            params.max_expansions,
1048                        )
1049                    }
1050                }
1051            } else {
1052                return Err(Error::index(
1053                    "tokens is not fst, which is not expected".to_owned(),
1054                ));
1055            }
1056        }
1057        Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
1058    }
1059
1060    // search the documents that contain the query
1061    // return the doc info and the doc length
1062    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
1063    #[instrument(level = "debug", skip_all)]
1064    pub async fn load_posting_lists(
1065        &self,
1066        tokens: &Tokens,
1067        params: &FtsSearchParams,
1068        metrics: &dyn MetricsCollector,
1069    ) -> Result<Vec<PostingIterator>> {
1070        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
1071        let is_phrase_query = params.phrase_slop.is_some();
1072        let tokens = match is_fuzzy {
1073            true => self.expand_fuzzy(tokens, params)?,
1074            false => tokens.clone(),
1075        };
1076        let token_positions = (0..tokens.len())
1077            .map(|index| tokens.position(index))
1078            .collect::<Vec<_>>();
1079        let mut token_ids = Vec::with_capacity(tokens.len());
1080        for (index, token) in tokens.into_iter().enumerate() {
1081            let token_id = self.map(&token);
1082            if let Some(token_id) = token_id {
1083                token_ids.push((token_id, token, token_positions[index]));
1084            } else if is_phrase_query {
1085                // if the token is not found, we can't do phrase query
1086                return Ok(Vec::new());
1087            }
1088        }
1089        if token_ids.is_empty() {
1090            return Ok(Vec::new());
1091        }
1092        if !is_phrase_query {
1093            token_ids.sort_unstable_by_key(|(token_id, _, _)| *token_id);
1094            token_ids.dedup_by_key(|(token_id, _, _)| *token_id);
1095        }
1096
1097        let num_docs = self.docs.len();
1098        stream::iter(token_ids)
1099            .map(|(token_id, token, position)| async move {
1100                let posting = self
1101                    .inverted_list
1102                    .posting_list(token_id, is_phrase_query, metrics)
1103                    .await?;
1104
1105                let query_weight = idf(posting.len(), num_docs);
1106
1107                Result::Ok(PostingIterator::with_query_weight(
1108                    token,
1109                    token_id,
1110                    position,
1111                    query_weight,
1112                    posting,
1113                    num_docs,
1114                ))
1115            })
1116            .buffered(self.store.io_parallelism())
1117            .try_collect::<Vec<_>>()
1118            .await
1119    }
1120
1121    #[instrument(level = "debug", skip_all)]
1122    pub fn bm25_search(
1123        &self,
1124        params: &FtsSearchParams,
1125        operator: Operator,
1126        mask: Arc<RowAddrMask>,
1127        postings: Vec<PostingIterator>,
1128        metrics: &dyn MetricsCollector,
1129    ) -> Result<Vec<DocCandidate>> {
1130        if postings.is_empty() {
1131            return Ok(Vec::new());
1132        }
1133
1134        // let local_metrics = LocalMetricsCollector::default();
1135        let scorer = IndexBM25Scorer::new(std::iter::once(self));
1136        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
1137        let hits = wand.search(params, mask, metrics)?;
1138        // local_metrics.dump_into(metrics);
1139        Ok(hits)
1140    }
1141
1142    pub async fn into_builder(self) -> Result<InnerBuilder> {
1143        let mut builder = InnerBuilder::new_with_posting_tail_codec(
1144            self.id,
1145            self.inverted_list.has_positions(),
1146            self.token_set_format,
1147            self.inverted_list.posting_tail_codec(),
1148        );
1149        builder.tokens = self.tokens;
1150        builder.docs = self.docs;
1151
1152        builder
1153            .posting_lists
1154            .reserve_exact(self.inverted_list.len());
1155        for posting_list in self
1156            .inverted_list
1157            .read_all(self.inverted_list.has_positions())
1158            .await?
1159        {
1160            let posting_list = posting_list?;
1161            builder
1162                .posting_lists
1163                .push(posting_list.into_builder(&builder.docs));
1164        }
1165        Ok(builder)
1166    }
1167}
1168
/// Token -> token id mapping, in either a mutable or a read-optimized form.
// at indexing, we use HashMap because we need it to be mutable,
// at searching, we use fst::Map because it's more efficient
#[derive(Debug, Clone)]
pub enum TokenMap {
    /// Mutable representation; `TokenSet::add` and `TokenSet::get_or_add` require it.
    HashMap(HashMap<String, u32>),
    /// FST-backed representation; token ids are stored as the `u64` values of the map.
    Fst(fst::Map<Vec<u8>>),
}
1176
1177impl Default for TokenMap {
1178    fn default() -> Self {
1179        Self::HashMap(HashMap::new())
1180    }
1181}
1182
1183impl DeepSizeOf for TokenMap {
1184    fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
1185        match self {
1186            Self::HashMap(map) => map.deep_size_of_children(ctx),
1187            Self::Fst(map) => map.as_fst().size(),
1188        }
1189    }
1190}
1191
1192impl TokenMap {
1193    pub fn len(&self) -> usize {
1194        match self {
1195            Self::HashMap(map) => map.len(),
1196            Self::Fst(map) => map.len(),
1197        }
1198    }
1199
1200    pub fn is_empty(&self) -> bool {
1201        self.len() == 0
1202    }
1203}
1204
/// A mapping from tokens to token ids, plus the bookkeeping needed to assign
/// new ids.
// TokenSet is a mapping from tokens to token ids
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct TokenSet {
    // token -> token_id
    pub(crate) tokens: TokenMap,
    // id handed to the next newly added token (see `add` / `get_or_add`)
    pub(crate) next_id: u32,
    // sum of the byte lengths of the tokens added so far
    total_length: usize,
}
1213
1214impl TokenSet {
1215    pub fn into_mut(self) -> Self {
1216        let tokens = match self.tokens {
1217            TokenMap::HashMap(map) => map,
1218            TokenMap::Fst(map) => {
1219                let mut new_map = HashMap::with_capacity(map.len());
1220                let mut stream = map.into_stream();
1221                while let Some((token, token_id)) = stream.next() {
1222                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1223                }
1224
1225                new_map
1226            }
1227        };
1228
1229        Self {
1230            tokens: TokenMap::HashMap(tokens),
1231            next_id: self.next_id,
1232            total_length: self.total_length,
1233        }
1234    }
1235
    /// Number of tokens in the set.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }
1239
    /// Returns `true` when the set holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
1243
    /// Serializes the token set into a record batch using the requested
    /// `format`: one row per token for `Arrow`, or a single row carrying the
    /// FST bytes for `Fst`.
    pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
        match format {
            TokenSetFormat::Arrow => self.into_arrow_batch(),
            TokenSetFormat::Fst => self.into_fst_batch(),
        }
    }
1250
1251    fn into_arrow_batch(self) -> Result<RecordBatch> {
1252        let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
1253        let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
1254
1255        match self.tokens {
1256            TokenMap::Fst(map) => {
1257                let mut stream = map.stream();
1258                while let Some((token, token_id)) = stream.next() {
1259                    token_builder.append_value(String::from_utf8_lossy(token));
1260                    token_id_builder.append_value(token_id as u32);
1261                }
1262            }
1263            TokenMap::HashMap(map) => {
1264                for (token, token_id) in map.into_iter().sorted_unstable() {
1265                    token_builder.append_value(token);
1266                    token_id_builder.append_value(token_id);
1267                }
1268            }
1269        }
1270
1271        let token_col = token_builder.finish();
1272        let token_id_col = token_id_builder.finish();
1273
1274        let schema = arrow_schema::Schema::new(vec![
1275            arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
1276            arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
1277        ]);
1278
1279        let batch = RecordBatch::try_new(
1280            Arc::new(schema),
1281            vec![
1282                Arc::new(token_col) as ArrayRef,
1283                Arc::new(token_id_col) as ArrayRef,
1284            ],
1285        )?;
1286        Ok(batch)
1287    }
1288
1289    fn into_fst_batch(mut self) -> Result<RecordBatch> {
1290        let fst_map = match std::mem::take(&mut self.tokens) {
1291            TokenMap::Fst(map) => map,
1292            TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
1293        };
1294        let bytes = fst_map.into_fst().into_inner();
1295
1296        let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
1297        fst_builder.append_value(bytes);
1298        let fst_col = fst_builder.finish();
1299
1300        let mut next_id_builder = UInt32Builder::with_capacity(1);
1301        next_id_builder.append_value(self.next_id);
1302        let next_id_col = next_id_builder.finish();
1303
1304        let mut total_length_builder = UInt64Builder::with_capacity(1);
1305        total_length_builder.append_value(self.total_length as u64);
1306        let total_length_col = total_length_builder.finish();
1307
1308        let schema = arrow_schema::Schema::new(vec![
1309            arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
1310            arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
1311            arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
1312        ]);
1313
1314        let batch = RecordBatch::try_new(
1315            Arc::new(schema),
1316            vec![
1317                Arc::new(fst_col) as ArrayRef,
1318                Arc::new(next_id_col) as ArrayRef,
1319                Arc::new(total_length_col) as ArrayRef,
1320            ],
1321        )?;
1322        Ok(batch)
1323    }
1324
1325    fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
1326        let mut entries: Vec<_> = map.into_iter().collect();
1327        entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1328        let mut builder = fst::MapBuilder::memory();
1329        for (token, token_id) in entries {
1330            builder
1331                .insert(&token, token_id as u64)
1332                .map_err(|e| Error::index(format!("failed to insert token {}: {}", token, e)))?;
1333        }
1334        Ok(builder.into_map())
1335    }
1336
    /// Loads a token set from `reader`, decoding it according to `format`.
    pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
        match format {
            TokenSetFormat::Arrow => Self::load_arrow(reader).await,
            TokenSetFormat::Fst => Self::load_fst(reader).await,
        }
    }
1343
1344    async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1345        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1346
1347        let (tokens, next_id, total_length) = spawn_blocking(move || {
1348            let mut next_id = 0;
1349            let mut total_length = 0;
1350            let mut tokens = fst::MapBuilder::memory();
1351
1352            let token_col = batch[TOKEN_COL].as_string::<i32>();
1353            let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1354
1355            for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1356                let token =
1357                    token.ok_or(Error::index("found null token in token set".to_owned()))?;
1358                next_id = next_id.max(token_id + 1);
1359                total_length += token.len();
1360                tokens.insert(token, token_id as u64).map_err(|e| {
1361                    Error::index(format!("failed to insert token {}: {}", token, e))
1362                })?;
1363            }
1364
1365            Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1366        })
1367        .await
1368        .map_err(|err| Error::execution(format!("failed to spawn blocking task: {}", err)))??;
1369
1370        Ok(Self {
1371            tokens: TokenMap::Fst(tokens),
1372            next_id,
1373            total_length,
1374        })
1375    }
1376
1377    async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1378        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1379        if batch.num_rows() == 0 {
1380            return Err(Error::index("token set batch is empty".to_owned()));
1381        }
1382
1383        let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1384        let bytes = fst_col.value(0);
1385        let map = fst::Map::new(bytes.to_vec())
1386            .map_err(|e| Error::index(format!("failed to load fst tokens: {}", e)))?;
1387
1388        let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1389        let total_length_col =
1390            batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1391
1392        let next_id = next_id_col
1393            .values()
1394            .first()
1395            .copied()
1396            .ok_or(Error::index("token next id column is empty".to_owned()))?;
1397
1398        let total_length = total_length_col
1399            .values()
1400            .first()
1401            .copied()
1402            .ok_or(Error::index(
1403                "token total length column is empty".to_owned(),
1404            ))?;
1405
1406        Ok(Self {
1407            tokens: TokenMap::Fst(map),
1408            next_id,
1409            total_length: usize::try_from(total_length).map_err(|_| {
1410                Error::index(format!(
1411                    "token total length {} overflows usize",
1412                    total_length
1413                ))
1414            })?,
1415        })
1416    }
1417
1418    pub fn add(&mut self, token: String) -> u32 {
1419        let next_id = self.next_id();
1420        let len = token.len();
1421        let token_id = match self.tokens {
1422            TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1423            _ => unreachable!("tokens must be HashMap while indexing"),
1424        };
1425
1426        // add token if it doesn't exist
1427        if token_id == next_id {
1428            self.next_id += 1;
1429            self.total_length += len;
1430        }
1431
1432        token_id
1433    }
1434
1435    pub(crate) fn get_or_add(&mut self, token: &str) -> u32 {
1436        let next_id = self.next_id;
1437        match self.tokens {
1438            TokenMap::HashMap(ref mut map) => {
1439                if let Some(&token_id) = map.get(token) {
1440                    return token_id;
1441                }
1442
1443                map.insert(token.to_owned(), next_id);
1444            }
1445            _ => unreachable!("tokens must be HashMap while indexing"),
1446        }
1447
1448        self.next_id += 1;
1449        self.total_length += token.len();
1450        next_id
1451    }
1452
    /// Looks up the id of `token` without modifying the set; works for both
    /// representations (FST values are narrowed to `u32`).
    pub fn get(&self, token: &str) -> Option<u32> {
        match self.tokens {
            TokenMap::HashMap(ref map) => map.get(token).copied(),
            TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
        }
    }
1459
1460    // the `removed_token_ids` must be sorted
1461    pub fn remap(&mut self, removed_token_ids: &[u32]) {
1462        if removed_token_ids.is_empty() {
1463            return;
1464        }
1465
1466        let mut map = match std::mem::take(&mut self.tokens) {
1467            TokenMap::HashMap(map) => map,
1468            TokenMap::Fst(map) => {
1469                let mut new_map = HashMap::with_capacity(map.len());
1470                let mut stream = map.into_stream();
1471                while let Some((token, token_id)) = stream.next() {
1472                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1473                }
1474
1475                new_map
1476            }
1477        };
1478
1479        map.retain(
1480            |_, token_id| match removed_token_ids.binary_search(token_id) {
1481                Ok(_) => false,
1482                Err(index) => {
1483                    *token_id -= index as u32;
1484                    true
1485                }
1486            },
1487        );
1488
1489        self.tokens = TokenMap::HashMap(map);
1490    }
1491
    /// Returns the id that will be assigned to the next newly added token.
    pub fn next_id(&self) -> u32 {
        self.next_id
    }
1495
1496    pub(crate) fn memory_size(&self) -> usize {
1497        match &self.tokens {
1498            TokenMap::HashMap(map) => {
1499                self.total_length
1500                    + map.capacity()
1501                        * (std::mem::size_of::<String>()
1502                            + std::mem::size_of::<u32>()
1503                            + std::mem::size_of::<usize>())
1504            }
1505            TokenMap::Fst(map) => map.as_fst().size(),
1506        }
1507    }
1508}
1509
/// Reads posting lists from an index file, supporting both the legacy layout
/// (row ranges described by `offsets`) and the new layout (one row per posting
/// list, with `lengths` and `max_scores` columns).
pub struct PostingListReader {
    // underlying index file
    reader: Arc<dyn IndexReader>,

    // legacy format only
    offsets: Option<Vec<usize>>,

    // from metadata for legacy format
    // from column for new format
    max_scores: Option<Vec<f32>>,

    // new format only
    lengths: Option<Vec<u32>>,

    // true when `positions_layout` is not `PositionsLayout::None`
    has_position: bool,
    // parsed from the schema metadata of the index file
    posting_tail_codec: PostingTailCodec,
    // how positions are stored, derived from the columns present in the file
    positions_layout: PositionsLayout,

    // weak handle to the index cache, used to cache loaded posting lists
    index_cache: WeakLanceCache,
}
1529
/// How token positions are stored in a posting list file, as detected by
/// `PostingListReader::try_new` from the columns present in the schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PositionsLayout {
    /// Neither position column is present.
    None,
    /// The file has a `POSITION_COL` column (and no `COMPRESSED_POSITION_COL`).
    LegacyPerDoc,
    /// The file has a `COMPRESSED_POSITION_COL` column; carries the codec
    /// parsed from the schema metadata.
    SharedStream(PositionStreamCodec),
}
1536
1537impl std::fmt::Debug for PostingListReader {
1538    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1539        f.debug_struct("InvertedListReader")
1540            .field("offsets", &self.offsets)
1541            .field("max_scores", &self.max_scores)
1542            .finish()
1543    }
1544}
1545
impl DeepSizeOf for PostingListReader {
    /// Counts only the in-memory metadata vectors (`offsets`, `max_scores`,
    /// `lengths`); the file reader and the cache handle are not included.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.offsets.deep_size_of_children(context)
            + self.max_scores.deep_size_of_children(context)
            + self.lengths.deep_size_of_children(context)
    }
}
1553
1554impl PostingListReader {
1555    pub(crate) async fn try_new(
1556        reader: Arc<dyn IndexReader>,
1557        index_cache: &LanceCache,
1558    ) -> Result<Self> {
1559        let positions_layout = if reader.schema().field(COMPRESSED_POSITION_COL).is_some() {
1560            PositionsLayout::SharedStream(parse_shared_position_codec(&reader.schema().metadata)?)
1561        } else if reader.schema().field(POSITION_COL).is_some() {
1562            PositionsLayout::LegacyPerDoc
1563        } else {
1564            PositionsLayout::None
1565        };
1566        let posting_tail_codec = parse_posting_tail_codec(&reader.schema().metadata)?;
1567        let has_position = positions_layout != PositionsLayout::None;
1568        let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
1569            let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
1570            (Some(offsets), max_scores, None)
1571        } else {
1572            let metadata = reader
1573                .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
1574                .await?;
1575            let max_scores = metadata[MAX_SCORE_COL]
1576                .as_primitive::<Float32Type>()
1577                .values()
1578                .to_vec();
1579            let lengths = metadata[LENGTH_COL]
1580                .as_primitive::<UInt32Type>()
1581                .values()
1582                .to_vec();
1583            (None, Some(max_scores), Some(lengths))
1584        };
1585
1586        Ok(Self {
1587            reader,
1588            offsets,
1589            max_scores,
1590            lengths,
1591            has_position,
1592            posting_tail_codec,
1593            positions_layout,
1594            index_cache: WeakLanceCache::from(index_cache),
1595        })
1596    }
1597
1598    // for legacy format
1599    // returns the offsets and max scores
1600    fn load_metadata(
1601        schema: &lance_core::datatypes::Schema,
1602    ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
1603        let offsets = schema
1604            .metadata
1605            .get("offsets")
1606            .ok_or(Error::index("offsets not found in metadata".to_owned()))?;
1607        let offsets = serde_json::from_str(offsets)?;
1608
1609        let max_scores = schema
1610            .metadata
1611            .get("max_scores")
1612            .map(|max_scores| serde_json::from_str(max_scores))
1613            .transpose()?;
1614        Ok((offsets, max_scores))
1615    }
1616
1617    // the number of posting lists
1618    pub fn len(&self) -> usize {
1619        match self.offsets {
1620            Some(ref offsets) => offsets.len(),
1621            None => self.reader.num_rows(),
1622        }
1623    }
1624
    /// Returns `true` when the reader holds no posting lists.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
1628
    /// Returns `true` when the file stores token positions in any layout.
    pub(crate) fn has_positions(&self) -> bool {
        self.has_position
    }
1632
    /// Returns the posting tail codec parsed from the file's schema metadata.
    pub(crate) fn posting_tail_codec(&self) -> PostingTailCodec {
        self.posting_tail_codec
    }
1636
1637    pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1638        let token_id = token_id as usize;
1639
1640        match self.offsets {
1641            Some(ref offsets) => {
1642                let next_offset = offsets
1643                    .get(token_id + 1)
1644                    .copied()
1645                    .unwrap_or(self.reader.num_rows());
1646                next_offset - offsets[token_id]
1647            }
1648            None => {
1649                if let Some(lengths) = &self.lengths {
1650                    lengths[token_id] as usize
1651                } else {
1652                    panic!("posting list reader is not initialized")
1653                }
1654            }
1655        }
1656    }
1657
1658    pub(crate) async fn posting_batch(
1659        &self,
1660        token_id: u32,
1661        with_position: bool,
1662    ) -> Result<RecordBatch> {
1663        if self.offsets.is_some() {
1664            self.posting_batch_legacy(token_id, with_position).await
1665        } else {
1666            let token_id = token_id as usize;
1667            let columns = if with_position {
1668                match self.positions_layout {
1669                    PositionsLayout::SharedStream(_) => {
1670                        vec![
1671                            POSTING_COL,
1672                            COMPRESSED_POSITION_COL,
1673                            POSITION_BLOCK_OFFSET_COL,
1674                        ]
1675                    }
1676                    PositionsLayout::LegacyPerDoc => vec![POSTING_COL, POSITION_COL],
1677                    PositionsLayout::None => vec![POSTING_COL],
1678                }
1679            } else {
1680                vec![POSTING_COL]
1681            };
1682            let batch = self
1683                .reader
1684                .read_range(token_id..token_id + 1, Some(&columns))
1685                .await?;
1686            Ok(batch)
1687        }
1688    }
1689
1690    async fn posting_batch_legacy(
1691        &self,
1692        token_id: u32,
1693        with_position: bool,
1694    ) -> Result<RecordBatch> {
1695        let mut columns = vec![ROW_ID, FREQUENCY_COL];
1696        if with_position {
1697            columns.push(POSITION_COL);
1698        }
1699
1700        let length = self.posting_len(token_id);
1701        let token_id = token_id as usize;
1702        let offset = self.offsets.as_ref().unwrap()[token_id];
1703        let batch = self
1704            .reader
1705            .read_range(offset..offset + length, Some(&columns))
1706            .await?;
1707        Ok(batch)
1708    }
1709
1710    #[instrument(level = "debug", skip(self, metrics))]
1711    pub(crate) async fn posting_list(
1712        &self,
1713        token_id: u32,
1714        is_phrase_query: bool,
1715        metrics: &dyn MetricsCollector,
1716    ) -> Result<PostingList> {
1717        let cache_key = PostingListKey { token_id };
1718        let mut posting = self
1719            .index_cache
1720            .get_or_insert_with_key(cache_key, || async move {
1721                metrics.record_part_load();
1722                info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
1723                let batch = self.posting_batch(token_id, false).await?;
1724                self.posting_list_from_batch(&batch, token_id)
1725            })
1726            .await?
1727            .as_ref()
1728            .clone();
1729
1730        if is_phrase_query && !posting.has_position() {
1731            // hit the cache and when the cache was populated, the positions column was not loaded
1732            let positions = self.read_positions(token_id).await?;
1733            posting.set_positions(positions);
1734        }
1735
1736        Ok(posting)
1737    }
1738
1739    fn posting_list_from_batch_parts(
1740        batch: &RecordBatch,
1741        max_score: Option<f32>,
1742        length: Option<u32>,
1743        posting_tail_codec: PostingTailCodec,
1744        positions_layout: PositionsLayout,
1745    ) -> Result<PostingList> {
1746        let posting_list = PostingList::from_batch_with_tail_codec_and_positions_layout(
1747            batch,
1748            max_score,
1749            length,
1750            posting_tail_codec,
1751            positions_layout,
1752        )?;
1753        Ok(posting_list)
1754    }
1755
1756    pub(crate) fn posting_list_from_batch(
1757        &self,
1758        batch: &RecordBatch,
1759        token_id: u32,
1760    ) -> Result<PostingList> {
1761        Self::posting_list_from_batch_parts(
1762            batch,
1763            self.max_scores
1764                .as_ref()
1765                .map(|max_scores| max_scores[token_id as usize]),
1766            self.lengths
1767                .as_ref()
1768                .map(|lengths| lengths[token_id as usize]),
1769            self.posting_tail_codec,
1770            self.positions_layout,
1771        )
1772    }
1773
1774    fn build_prewarm_posting_lists(
1775        batch: RecordBatch,
1776        offsets: Option<Vec<usize>>,
1777        max_scores: Option<Vec<f32>>,
1778        lengths: Option<Vec<u32>>,
1779        posting_tail_codec: PostingTailCodec,
1780        positions_layout: PositionsLayout,
1781    ) -> Result<Vec<(u32, PostingList)>> {
1782        let token_count = if let Some(offsets) = offsets.as_ref() {
1783            offsets.len()
1784        } else if let Some(lengths) = lengths.as_ref() {
1785            lengths.len()
1786        } else {
1787            batch.num_rows()
1788        };
1789
1790        let mut posting_lists = Vec::with_capacity(token_count);
1791        for token_id in 0..token_count {
1792            let batch = if let Some(offsets) = offsets.as_ref() {
1793                let start = offsets[token_id];
1794                let end = if token_id + 1 < offsets.len() {
1795                    offsets[token_id + 1]
1796                } else {
1797                    batch.num_rows()
1798                };
1799                batch.slice(start, end - start)
1800            } else {
1801                batch.slice(token_id, 1)
1802            };
1803            let batch = batch.shrink_to_fit()?;
1804            let posting_list = Self::posting_list_from_batch_parts(
1805                &batch,
1806                max_scores.as_ref().map(|scores| scores[token_id]),
1807                lengths.as_ref().map(|lengths| lengths[token_id]),
1808                posting_tail_codec,
1809                positions_layout,
1810            )?;
1811            posting_lists.push((token_id as u32, posting_list));
1812        }
1813
1814        Ok(posting_lists)
1815    }
1816
    /// Loads every posting list and inserts it into `index_cache`.
    ///
    /// All rows are read with a single [`Self::read_batch`] call, decoded into
    /// posting lists inside `spawn_blocking`, and then inserted into the cache
    /// one token at a time. When `with_position` is set, positions are taken
    /// out of each posting list and cached separately under [`PositionKey`].
    ///
    /// # Errors
    /// Returns an invalid-input error if `with_position` is requested but the
    /// index has no positions, an internal error if the blocking task fails,
    /// and propagates read and decode errors.
    async fn prewarm_posting_lists(&self, with_position: bool) -> Result<()> {
        if with_position && !self.has_positions() {
            return Err(Error::invalid_input(
                "cannot prewarm positions for an inverted index that was built without positions; recreate the index with with_position=true".to_owned(),
            ));
        }

        let read_batch_start = Instant::now();
        let batch = self.read_batch(with_position).await?;
        let read_batch_elapsed = read_batch_start.elapsed();

        let legacy_layout = self.offsets.is_some();
        // The closure below is `move`, so it gets owned copies of the lookup
        // tables rather than borrowing from `self`.
        let offsets = self.offsets.clone();
        let max_scores = self.max_scores.clone();
        let lengths = self.lengths.clone();
        let posting_tail_codec = self.posting_tail_codec;
        let positions_layout = self.positions_layout;
        // This timer covers both the decoding task and the cache inserts.
        let populate_start = Instant::now();
        let posting_lists = spawn_blocking(move || {
            Self::build_prewarm_posting_lists(
                batch,
                offsets,
                max_scores,
                lengths,
                posting_tail_codec,
                positions_layout,
            )
        })
        .await
        .map_err(|err| {
            Error::internal(format!(
                "Failed to build prewarm posting lists in blocking task: {err}"
            ))
        })??;
        for (token_id, mut posting_list) in posting_lists {
            // Positions are moved out first, so the posting list cached below
            // no longer holds them.
            if with_position && let Some(positions) = posting_list.take_positions() {
                self.index_cache
                    .insert_with_key(&PositionKey { token_id }, Arc::new(Positions(positions)))
                    .await;
            }
            self.index_cache
                .insert_with_key(&PostingListKey { token_id }, Arc::new(posting_list))
                .await;
        }
        let populate_elapsed = populate_start.elapsed();

        info!(
            legacy_layout,
            with_position,
            token_count = self.len(),
            read_batch_ms = read_batch_elapsed.as_secs_f64() * 1000.0,
            post_read_loop_ms = populate_elapsed.as_secs_f64() * 1000.0,
            "posting list prewarm timing"
        );

        Ok(())
    }
1874
1875    pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1876        let columns = self.posting_columns(with_position);
1877        let batch = self
1878            .reader
1879            .read_range(0..self.reader.num_rows(), Some(&columns))
1880            .await?;
1881        Ok(batch)
1882    }
1883
1884    pub(crate) async fn read_all(
1885        &self,
1886        with_position: bool,
1887    ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1888        let batch = self.read_batch(with_position).await?;
1889        Ok((0..self.len()).map(move |i| {
1890            let token_id = i as u32;
1891            let range = self.posting_list_range(token_id);
1892            let batch = batch.slice(i, range.end - range.start);
1893            self.posting_list_from_batch(&batch, token_id)
1894        }))
1895    }
1896
1897    async fn read_positions(&self, token_id: u32) -> Result<CompressedPositionStorage> {
1898        let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
1899            let positions = match self.positions_layout {
1900                PositionsLayout::None => {
1901                    return Err(Error::invalid_input(
1902                        "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
1903                    ));
1904                }
1905                PositionsLayout::LegacyPerDoc => {
1906                    let batch = self
1907                        .reader
1908                        .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
1909                        .await
1910                        .map_err(|e| match e {
1911                            Error::Schema { .. } => Error::invalid_input("position is not found but required for phrase queries, try recreating the index with position".to_owned()),
1912                            e => e,
1913                        })?;
1914                    CompressedPositionStorage::LegacyPerDoc(
1915                        batch[POSITION_COL].as_list::<i32>().value(0).as_list::<i32>().clone(),
1916                    )
1917                }
1918                PositionsLayout::SharedStream(codec) => {
1919                    let batch = self
1920                        .reader
1921                        .read_range(
1922                            self.posting_list_range(token_id),
1923                            Some(&[COMPRESSED_POSITION_COL, POSITION_BLOCK_OFFSET_COL]),
1924                        )
1925                        .await
1926                        .map_err(|e| match e {
1927                            Error::Schema { .. } => Error::invalid_input("position is not found but required for phrase queries, try recreating the index with position".to_owned()),
1928                            e => e,
1929                        })?;
1930                    let bytes = batch[COMPRESSED_POSITION_COL]
1931                        .as_binary::<i64>()
1932                        .value(0)
1933                        .to_vec();
1934                    let block_offsets = batch[POSITION_BLOCK_OFFSET_COL]
1935                        .as_list::<i32>()
1936                        .value(0)
1937                        .as_primitive::<UInt32Type>()
1938                        .values()
1939                        .to_vec();
1940                    CompressedPositionStorage::SharedStream(SharedPositionStream::new(
1941                        codec,
1942                        block_offsets,
1943                        bytes,
1944                    ))
1945                }
1946            };
1947            Result::Ok(Positions(positions))
1948        }).await?;
1949        Ok(positions.0.clone())
1950    }
1951
1952    fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1953        match self.offsets {
1954            Some(ref offsets) => {
1955                let offset = offsets[token_id as usize];
1956                let posting_len = self.posting_len(token_id);
1957                offset..offset + posting_len
1958            }
1959            None => {
1960                let token_id = token_id as usize;
1961                token_id..token_id + 1
1962            }
1963        }
1964    }
1965
1966    fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1967        let mut base_columns = match self.offsets {
1968            Some(_) => vec![ROW_ID, FREQUENCY_COL],
1969            None => vec![POSTING_COL],
1970        };
1971        if with_position {
1972            match self.positions_layout {
1973                PositionsLayout::None => {}
1974                PositionsLayout::LegacyPerDoc => base_columns.push(POSITION_COL),
1975                PositionsLayout::SharedStream(_) => {
1976                    base_columns.push(COMPRESSED_POSITION_COL);
1977                    base_columns.push(POSITION_BLOCK_OFFSET_COL);
1978                }
1979            }
1980        }
1981        base_columns
1982    }
1983}
1984
/// Newtype around [`CompressedPositionStorage`] with its own `DeepSizeOf`
/// implementation, used as the value type stored in the cache under
/// [`PositionKey`].
#[derive(Clone)]
pub struct Positions(CompressedPositionStorage);
1989
1990impl DeepSizeOf for Positions {
1991    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1992        match &self.0 {
1993            CompressedPositionStorage::LegacyPerDoc(positions) => {
1994                positions.get_buffer_memory_size()
1995            }
1996            CompressedPositionStorage::SharedStream(stream) => stream.size(),
1997        }
1998    }
1999}
2000
// Cache key implementations for type-safe cache access

/// Cache key for the [`PostingList`] of a single token.
#[derive(Debug, Clone)]
pub struct PostingListKey {
    /// Id of the token whose posting list is cached.
    pub token_id: u32,
}
2006
2007impl CacheKey for PostingListKey {
2008    type ValueType = PostingList;
2009
2010    fn key(&self) -> std::borrow::Cow<'_, str> {
2011        format!("postings-{}", self.token_id).into()
2012    }
2013
2014    fn type_name() -> &'static str {
2015        "PostingList"
2016    }
2017}
2018
/// Cache key for the [`Positions`] of a single token.
#[derive(Debug, Clone)]
pub struct PositionKey {
    /// Id of the token whose positions are cached.
    pub token_id: u32,
}
2023
2024impl CacheKey for PositionKey {
2025    type ValueType = Positions;
2026
2027    fn key(&self) -> std::borrow::Cow<'_, str> {
2028        format!("positions-{}", self.token_id).into()
2029    }
2030
2031    fn type_name() -> &'static str {
2032        "Position"
2033    }
2034}
2035
/// Position data attached to a posting list, in one of two storage layouts.
#[derive(Debug, Clone, PartialEq)]
pub enum CompressedPositionStorage {
    /// Positions held in an Arrow [`ListArray`].
    LegacyPerDoc(ListArray),
    /// Positions held in a [`SharedPositionStream`]: one byte buffer that is
    /// split into blocks by an offset table.
    SharedStream(SharedPositionStream),
}
2041
2042impl DeepSizeOf for CompressedPositionStorage {
2043    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2044        match self {
2045            Self::LegacyPerDoc(positions) => positions.get_buffer_memory_size(),
2046            Self::SharedStream(stream) => stream.size(),
2047        }
2048    }
2049}
2050
/// Encoded positions stored as one contiguous byte buffer that is divided into
/// blocks by a table of start offsets.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SharedPositionStream {
    /// Codec the stream was created with.
    codec: PositionStreamCodec,
    /// Start offset of each block within `bytes`; a block ends where the next
    /// one starts, and the last block ends at the end of `bytes`.
    block_offsets: Vec<u32>,
    /// Concatenated bytes of all blocks.
    bytes: Vec<u8>,
}
2057
2058impl SharedPositionStream {
2059    pub fn new(codec: PositionStreamCodec, block_offsets: Vec<u32>, bytes: Vec<u8>) -> Self {
2060        Self {
2061            codec,
2062            block_offsets,
2063            bytes,
2064        }
2065    }
2066
2067    pub fn codec(&self) -> PositionStreamCodec {
2068        self.codec
2069    }
2070
2071    pub fn block_count(&self) -> usize {
2072        self.block_offsets.len()
2073    }
2074
2075    pub fn block_range(&self, index: usize) -> Range<usize> {
2076        let start = self.block_offsets[index] as usize;
2077        let end = self
2078            .block_offsets
2079            .get(index + 1)
2080            .map(|offset| *offset as usize)
2081            .unwrap_or(self.bytes.len());
2082        start..end
2083    }
2084
2085    pub fn block(&self, index: usize) -> &[u8] {
2086        let range = self.block_range(index);
2087        &self.bytes[range]
2088    }
2089
2090    pub fn bytes(&self) -> &[u8] {
2091        &self.bytes
2092    }
2093
2094    pub fn block_offsets(&self) -> &[u32] {
2095        &self.block_offsets
2096    }
2097
2098    pub fn size(&self) -> usize {
2099        self.block_offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.capacity()
2100    }
2101}
2102
/// A posting list in one of two in-memory representations.
#[derive(Debug, Clone, DeepSizeOf)]
pub enum PostingList {
    /// Uncompressed row ids and frequencies.
    Plain(PlainPostingList),
    /// Posting data kept as encoded blocks.
    Compressed(CompressedPostingList),
}
2108
2109impl PostingList {
2110    pub fn from_batch(
2111        batch: &RecordBatch,
2112        max_score: Option<f32>,
2113        length: Option<u32>,
2114    ) -> Result<Self> {
2115        let posting_tail_codec = parse_posting_tail_codec(batch.schema_ref().metadata())?;
2116        Self::from_batch_with_tail_codec(batch, max_score, length, posting_tail_codec)
2117    }
2118
2119    pub fn from_batch_with_tail_codec(
2120        batch: &RecordBatch,
2121        max_score: Option<f32>,
2122        length: Option<u32>,
2123        posting_tail_codec: PostingTailCodec,
2124    ) -> Result<Self> {
2125        let positions_layout = if batch.column_by_name(COMPRESSED_POSITION_COL).is_some() {
2126            PositionsLayout::SharedStream(parse_shared_position_codec(
2127                batch.schema_ref().metadata(),
2128            )?)
2129        } else if batch.column_by_name(POSITION_COL).is_some() {
2130            PositionsLayout::LegacyPerDoc
2131        } else {
2132            PositionsLayout::None
2133        };
2134        Self::from_batch_with_tail_codec_and_positions_layout(
2135            batch,
2136            max_score,
2137            length,
2138            posting_tail_codec,
2139            positions_layout,
2140        )
2141    }
2142
2143    fn from_batch_with_tail_codec_and_positions_layout(
2144        batch: &RecordBatch,
2145        max_score: Option<f32>,
2146        length: Option<u32>,
2147        posting_tail_codec: PostingTailCodec,
2148        positions_layout: PositionsLayout,
2149    ) -> Result<Self> {
2150        match batch.column_by_name(POSTING_COL) {
2151            Some(_) => {
2152                debug_assert!(max_score.is_some() && length.is_some());
2153                let shared_position_codec = match positions_layout {
2154                    PositionsLayout::SharedStream(codec) => Some(codec),
2155                    _ => None,
2156                };
2157                let posting = CompressedPostingList::from_batch(
2158                    batch,
2159                    max_score.unwrap(),
2160                    length.unwrap(),
2161                    posting_tail_codec,
2162                    shared_position_codec,
2163                );
2164                Ok(Self::Compressed(posting))
2165            }
2166            None => {
2167                let posting = PlainPostingList::from_batch(batch, max_score);
2168                Ok(Self::Plain(posting))
2169            }
2170        }
2171    }
2172
2173    pub fn iter(&self) -> PostingListIterator<'_> {
2174        PostingListIterator::new(self)
2175    }
2176
2177    pub fn has_position(&self) -> bool {
2178        match self {
2179            Self::Plain(posting) => posting.positions.is_some(),
2180            Self::Compressed(posting) => posting.positions.is_some(),
2181        }
2182    }
2183
2184    pub fn set_positions(&mut self, positions: CompressedPositionStorage) {
2185        match self {
2186            Self::Plain(posting) => match positions {
2187                CompressedPositionStorage::LegacyPerDoc(positions) => {
2188                    posting.positions = Some(positions)
2189                }
2190                CompressedPositionStorage::SharedStream(_) => {
2191                    unreachable!("shared position stream is not supported for plain postings")
2192                }
2193            },
2194            Self::Compressed(posting) => {
2195                posting.positions = Some(positions);
2196            }
2197        }
2198    }
2199
2200    pub fn take_positions(&mut self) -> Option<CompressedPositionStorage> {
2201        match self {
2202            Self::Plain(posting) => posting
2203                .positions
2204                .take()
2205                .map(CompressedPositionStorage::LegacyPerDoc),
2206            Self::Compressed(posting) => posting.positions.take(),
2207        }
2208    }
2209
2210    pub fn max_score(&self) -> Option<f32> {
2211        match self {
2212            Self::Plain(posting) => posting.max_score,
2213            Self::Compressed(posting) => Some(posting.max_score),
2214        }
2215    }
2216
2217    pub fn len(&self) -> usize {
2218        match self {
2219            Self::Plain(posting) => posting.len(),
2220            Self::Compressed(posting) => posting.length as usize,
2221        }
2222    }
2223
2224    pub fn is_empty(&self) -> bool {
2225        self.len() == 0
2226    }
2227
2228    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
2229        let posting_tail_codec = match &self {
2230            Self::Plain(_) => PostingTailCodec::Fixed32,
2231            Self::Compressed(posting) => posting.posting_tail_codec,
2232        };
2233        let mut builder = PostingListBuilder::new_with_posting_tail_codec(
2234            self.has_position(),
2235            posting_tail_codec,
2236        );
2237        match self {
2238            // legacy format
2239            Self::Plain(posting) => {
2240                // convert the posting list to the new format:
2241                // 1. map row ids to doc ids
2242                // 2. sort the posting list by doc ids
2243                struct Item {
2244                    doc_id: u32,
2245                    positions: PositionRecorder,
2246                }
2247                let doc_ids = docs
2248                    .row_ids
2249                    .iter()
2250                    .enumerate()
2251                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2252                    .collect::<HashMap<_, _>>();
2253                let mut items = Vec::with_capacity(posting.len());
2254                for (row_id, freq, positions) in posting.iter() {
2255                    let freq = freq as u32;
2256                    let positions = match positions {
2257                        Some(positions) => {
2258                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
2259                        }
2260                        None => PositionRecorder::Count(freq),
2261                    };
2262                    items.push(Item {
2263                        doc_id: doc_ids[&row_id],
2264                        positions,
2265                    });
2266                }
2267                items.sort_unstable_by_key(|item| item.doc_id);
2268                for item in items {
2269                    builder.add(item.doc_id, item.positions);
2270                }
2271            }
2272            Self::Compressed(posting) => {
2273                posting.iter().for_each(|(doc_id, freq, positions)| {
2274                    let positions = match positions {
2275                        Some(positions) => {
2276                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
2277                        }
2278                        None => PositionRecorder::Count(freq),
2279                    };
2280                    builder.add(doc_id, positions);
2281                });
2282            }
2283        }
2284        builder
2285    }
2286}
2287
/// Uncompressed posting list: parallel buffers of row ids and frequencies,
/// with optional per-entry positions.
#[derive(Debug, PartialEq, Clone)]
pub struct PlainPostingList {
    /// Row id of each entry.
    pub row_ids: ScalarBuffer<u64>,
    /// Frequency of each entry, parallel to `row_ids`.
    pub frequencies: ScalarBuffer<f32>,
    /// Max score of this posting list, if known.
    pub max_score: Option<f32>,
    /// Positions of each entry: one list element per entry.
    pub positions: Option<ListArray>, // List of Int32
}
2295
2296impl DeepSizeOf for PlainPostingList {
2297    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2298        self.row_ids.len() * std::mem::size_of::<u64>()
2299            + self.frequencies.len() * std::mem::size_of::<u32>()
2300            + self
2301                .positions
2302                .as_ref()
2303                .map(Array::get_buffer_memory_size)
2304                .unwrap_or(0)
2305    }
2306}
2307
2308impl PlainPostingList {
2309    pub fn new(
2310        row_ids: ScalarBuffer<u64>,
2311        frequencies: ScalarBuffer<f32>,
2312        max_score: Option<f32>,
2313        positions: Option<ListArray>,
2314    ) -> Self {
2315        Self {
2316            row_ids,
2317            frequencies,
2318            max_score,
2319            positions,
2320        }
2321    }
2322
2323    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
2324        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
2325        let frequencies = batch[FREQUENCY_COL]
2326            .as_primitive::<Float32Type>()
2327            .values()
2328            .clone();
2329        let positions = batch
2330            .column_by_name(POSITION_COL)
2331            .map(|col| col.as_list::<i32>().clone());
2332
2333        Self::new(row_ids, frequencies, max_score, positions)
2334    }
2335
2336    pub fn len(&self) -> usize {
2337        self.row_ids.len()
2338    }
2339
2340    pub fn is_empty(&self) -> bool {
2341        self.len() == 0
2342    }
2343
2344    pub fn iter(&self) -> PlainPostingListIterator<'_> {
2345        Box::new(
2346            self.row_ids
2347                .iter()
2348                .zip(self.frequencies.iter())
2349                .enumerate()
2350                .map(|(idx, (doc_id, freq))| {
2351                    (
2352                        *doc_id,
2353                        *freq,
2354                        self.positions.as_ref().map(|p| {
2355                            let start = p.value_offsets()[idx] as usize;
2356                            let end = p.value_offsets()[idx + 1] as usize;
2357                            Box::new(
2358                                p.values().as_primitive::<Int32Type>().values()[start..end]
2359                                    .iter()
2360                                    .map(|pos| *pos as u32),
2361                            ) as _
2362                        }),
2363                    )
2364                }),
2365        )
2366    }
2367
2368    #[inline]
2369    pub fn doc(&self, i: usize) -> LocatedDocInfo {
2370        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
2371    }
2372
2373    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
2374        self.positions
2375            .as_ref()
2376            .map(|positions| positions.value(index))
2377    }
2378
2379    pub fn max_score(&self) -> Option<f32> {
2380        self.max_score
2381    }
2382
2383    pub fn row_id(&self, i: usize) -> u64 {
2384        self.row_ids[i]
2385    }
2386}
2387
/// Posting list kept as encoded blocks, with optional positions.
#[derive(Debug, PartialEq, Clone)]
pub struct CompressedPostingList {
    /// Max score of this posting list.
    pub max_score: f32,
    /// Number of entries in the posting list.
    pub length: u32,
    // each binary is a block of compressed data
    // that contains `BLOCK_SIZE` doc ids and then `BLOCK_SIZE` frequencies
    pub blocks: LargeBinaryArray,
    /// Codec used for the last block when it holds fewer than `BLOCK_SIZE`
    /// entries.
    pub posting_tail_codec: PostingTailCodec,
    /// Positions attached to this posting list, if any.
    pub positions: Option<CompressedPositionStorage>,
}
2398
2399impl DeepSizeOf for CompressedPostingList {
2400    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2401        self.blocks.get_buffer_memory_size()
2402            + self
2403                .positions
2404                .as_ref()
2405                .map(|positions| match positions {
2406                    CompressedPositionStorage::LegacyPerDoc(positions) => {
2407                        positions.get_buffer_memory_size()
2408                    }
2409                    CompressedPositionStorage::SharedStream(stream) => stream.size(),
2410                })
2411                .unwrap_or(0)
2412    }
2413}
2414
2415impl CompressedPostingList {
2416    pub fn new(
2417        blocks: LargeBinaryArray,
2418        max_score: f32,
2419        length: u32,
2420        posting_tail_codec: PostingTailCodec,
2421        positions: Option<CompressedPositionStorage>,
2422    ) -> Self {
2423        Self {
2424            max_score,
2425            length,
2426            blocks,
2427            posting_tail_codec,
2428            positions,
2429        }
2430    }
2431
2432    pub fn from_batch(
2433        batch: &RecordBatch,
2434        max_score: f32,
2435        length: u32,
2436        posting_tail_codec: PostingTailCodec,
2437        shared_position_codec: Option<PositionStreamCodec>,
2438    ) -> Self {
2439        debug_assert_eq!(batch.num_rows(), 1);
2440        let blocks = batch[POSTING_COL]
2441            .as_list::<i32>()
2442            .value(0)
2443            .as_binary::<i64>()
2444            .clone();
2445        let positions = if let Some(col) = batch.column_by_name(COMPRESSED_POSITION_COL) {
2446            let bytes = col.as_binary::<i64>().value(0).to_vec();
2447            let block_offsets = batch[POSITION_BLOCK_OFFSET_COL]
2448                .as_list::<i32>()
2449                .value(0)
2450                .as_primitive::<UInt32Type>()
2451                .values()
2452                .to_vec();
2453            let codec = shared_position_codec.unwrap_or_else(|| {
2454                parse_shared_position_codec(batch.schema_ref().metadata())
2455                    .expect("shared position stream codec metadata should be valid")
2456            });
2457            Some(CompressedPositionStorage::SharedStream(
2458                SharedPositionStream::new(codec, block_offsets, bytes),
2459            ))
2460        } else {
2461            batch.column_by_name(POSITION_COL).map(|col| {
2462                CompressedPositionStorage::LegacyPerDoc(
2463                    col.as_list::<i32>().value(0).as_list::<i32>().clone(),
2464                )
2465            })
2466        };
2467
2468        Self {
2469            max_score,
2470            length,
2471            blocks,
2472            posting_tail_codec,
2473            positions,
2474        }
2475    }
2476
2477    pub fn iter(&self) -> CompressedPostingListIterator {
2478        CompressedPostingListIterator::new(
2479            self.length as usize,
2480            self.blocks.clone(),
2481            self.posting_tail_codec,
2482            self.positions.clone(),
2483        )
2484    }
2485
2486    pub fn block_max_score(&self, block_idx: usize) -> f32 {
2487        let block = self.blocks.value(block_idx);
2488        block[0..4].try_into().map(f32::from_le_bytes).unwrap()
2489    }
2490
2491    pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
2492        let block = self.blocks.value(block_idx);
2493        let remainder = self.length as usize % BLOCK_SIZE;
2494        let is_remainder_block = remainder > 0 && block_idx + 1 == self.blocks.len();
2495        if is_remainder_block {
2496            super::encoding::read_posting_tail_first_doc(block, self.posting_tail_codec)
2497        } else {
2498            block[4..8].try_into().map(u32::from_le_bytes).unwrap()
2499        }
2500    }
2501}
2502
/// Encoded posting blocks stored back to back in one byte buffer.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct EncodedBlocks {
    /// Start offset of each block within `bytes`; a block ends where the next
    /// one starts, and the last block ends at the end of `bytes`.
    offsets: Vec<u32>,
    /// Concatenated bytes of all blocks.
    bytes: Vec<u8>,
}
2508
2509impl EncodedBlocks {
2510    fn len(&self) -> usize {
2511        self.offsets.len()
2512    }
2513
2514    fn size(&self) -> usize {
2515        self.offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.capacity()
2516    }
2517
2518    fn push_full_block(&mut self, doc_ids: &[u32], frequencies: &[u32]) -> Result<usize> {
2519        let start = self.bytes.len();
2520        self.offsets.push(start as u32);
2521        super::encoding::encode_full_posting_block_into(doc_ids, frequencies, &mut self.bytes)?;
2522        Ok(self.bytes.len() - start)
2523    }
2524
2525    fn block(&self, index: usize) -> &[u8] {
2526        let (start, end) = self.block_range(index);
2527        &self.bytes[start..end]
2528    }
2529
2530    fn block_range(&self, index: usize) -> (usize, usize) {
2531        let start = self.offsets[index] as usize;
2532        let end = self
2533            .offsets
2534            .get(index + 1)
2535            .map(|offset| *offset as usize)
2536            .unwrap_or(self.bytes.len());
2537        (start, end)
2538    }
2539
2540    fn set_block_score(&mut self, index: usize, score: f32) {
2541        let (start, _) = self.block_range(index);
2542        self.bytes[start..start + 4].copy_from_slice(&score.to_le_bytes());
2543    }
2544
2545    fn append_remainder_block_with_codec(
2546        &mut self,
2547        doc_ids: &[u32],
2548        frequencies: &[u32],
2549        codec: PostingTailCodec,
2550    ) -> Result<()> {
2551        self.offsets.push(self.bytes.len() as u32);
2552        super::encoding::encode_remainder_posting_block_into(
2553            doc_ids,
2554            frequencies,
2555            codec,
2556            &mut self.bytes,
2557        )
2558    }
2559
2560    fn into_array(mut self) -> LargeBinaryArray {
2561        let mut offsets = Vec::with_capacity(self.offsets.len() + 1);
2562        offsets.extend(self.offsets.into_iter().map(i64::from));
2563        offsets.push(self.bytes.len() as i64);
2564        LargeBinaryArray::new(
2565            OffsetBuffer::new(ScalarBuffer::from(offsets)),
2566            Buffer::from_vec(std::mem::take(&mut self.bytes)),
2567            None,
2568        )
2569    }
2570
2571    fn iter(&self) -> impl Iterator<Item = &[u8]> {
2572        (0..self.len()).map(|index| self.block(index))
2573    }
2574}
2575
/// Encoded position blocks stored back to back in a single byte buffer.
///
/// Block `i` starts at `offsets[i]` and ends where block `i + 1` starts, or
/// at the end of `bytes` when it is the last block.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct EncodedPositionBlocks {
    /// Start offset of each position block inside `bytes`.
    offsets: Vec<u32>,
    /// Concatenated bytes of every encoded position block.
    bytes: Vec<u8>,
}
2581
2582impl EncodedPositionBlocks {
2583    fn size(&self) -> usize {
2584        self.offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.capacity()
2585    }
2586
2587    fn block(&self, index: usize) -> &[u8] {
2588        let start = self.offsets[index] as usize;
2589        let end = self
2590            .offsets
2591            .get(index + 1)
2592            .map(|offset| *offset as usize)
2593            .unwrap_or(self.bytes.len());
2594        &self.bytes[start..end]
2595    }
2596
2597    fn push_encoded_block(&mut self, block: &[u8]) -> usize {
2598        let start = self.bytes.len();
2599        self.offsets.push(start as u32);
2600        self.bytes.extend_from_slice(block);
2601        self.bytes.len() - start
2602    }
2603
2604    fn into_stream(self) -> SharedPositionStream {
2605        SharedPositionStream::new(PositionStreamCodec::PackedDelta, self.offsets, self.bytes)
2606    }
2607}
2608
2609#[derive(Debug)]
2610pub struct PostingListBuilder {
2611    with_positions: bool,
2612    posting_tail_codec: PostingTailCodec,
2613    encoded_blocks: Option<Box<EncodedBlocks>>,
2614    encoded_position_blocks: Option<Box<EncodedPositionBlocks>>,
2615    tail_entries: Vec<RawDocInfo>,
2616    tail_positions: PositionBlockBuilder,
2617    open_doc_id: Option<u32>,
2618    open_doc_frequency: u32,
2619    open_doc_last_position: Option<u32>,
2620    memory_size_bytes: u32,
2621    len: u32,
2622}
2623
2624pub(super) struct PostingListBatchBuilder {
2625    schema: SchemaRef,
2626    postings: ListBuilder<LargeBinaryBuilder>,
2627    max_scores: Float32Builder,
2628    lengths: UInt32Builder,
2629    positions: BatchPositionsBuilder,
2630    len: usize,
2631}
2632
2633enum BatchPositionsBuilder {
2634    None,
2635    Legacy(ListBuilder<ListBuilder<LargeBinaryBuilder>>),
2636    Shared {
2637        bytes: LargeBinaryBuilder,
2638        block_offsets: ListBuilder<UInt32Builder>,
2639    },
2640}
2641
2642struct PostingListParts<'a> {
2643    with_positions: bool,
2644    posting_tail_codec: PostingTailCodec,
2645    length: usize,
2646    encoded_blocks: EncodedBlocks,
2647    encoded_position_blocks: EncodedPositionBlocks,
2648    tail_entries: &'a [RawDocInfo],
2649    tail_position_block: Option<Vec<u8>>,
2650}
2651
2652impl PostingListBatchBuilder {
2653    pub fn new(
2654        schema: SchemaRef,
2655        with_positions: bool,
2656        format_version: InvertedListFormatVersion,
2657        capacity: usize,
2658    ) -> Self {
2659        let positions = if !with_positions {
2660            BatchPositionsBuilder::None
2661        } else if format_version.uses_shared_position_stream() {
2662            BatchPositionsBuilder::Shared {
2663                bytes: LargeBinaryBuilder::with_capacity(capacity, 0),
2664                block_offsets: ListBuilder::with_capacity(UInt32Builder::new(), capacity),
2665            }
2666        } else {
2667            BatchPositionsBuilder::Legacy(ListBuilder::with_capacity(
2668                ListBuilder::new(LargeBinaryBuilder::new()),
2669                capacity,
2670            ))
2671        };
2672        Self {
2673            schema,
2674            postings: ListBuilder::with_capacity(LargeBinaryBuilder::new(), capacity),
2675            max_scores: Float32Builder::with_capacity(capacity),
2676            lengths: UInt32Builder::with_capacity(capacity),
2677            positions,
2678            len: 0,
2679        }
2680    }
2681
2682    pub fn len(&self) -> usize {
2683        self.len
2684    }
2685
2686    pub fn is_empty(&self) -> bool {
2687        self.len == 0
2688    }
2689
2690    fn append(
2691        &mut self,
2692        compressed: LargeBinaryArray,
2693        max_score: f32,
2694        length: u32,
2695        positions: Option<&CompressedPositionStorage>,
2696    ) -> Result<()> {
2697        {
2698            let values = self.postings.values();
2699            for index in 0..compressed.len() {
2700                values.append_value(compressed.value(index));
2701            }
2702        }
2703        self.postings.append(true);
2704        self.max_scores.append_value(max_score);
2705        self.lengths.append_value(length);
2706
2707        match &mut self.positions {
2708            BatchPositionsBuilder::None => {}
2709            BatchPositionsBuilder::Shared {
2710                bytes,
2711                block_offsets,
2712            } => {
2713                let positions = positions.ok_or_else(|| {
2714                    Error::index(format!(
2715                        "positions builder missing position data for posting length {}",
2716                        length
2717                    ))
2718                })?;
2719                let CompressedPositionStorage::SharedStream(positions) = positions else {
2720                    return Err(Error::index(
2721                        "shared positions builder received legacy positions".to_owned(),
2722                    ));
2723                };
2724                bytes.append_value(positions.bytes());
2725                let offsets_builder = block_offsets.values();
2726                for &offset in positions.block_offsets() {
2727                    offsets_builder.append_value(offset);
2728                }
2729                block_offsets.append(true);
2730            }
2731            BatchPositionsBuilder::Legacy(position_lists) => {
2732                let positions = positions.ok_or_else(|| {
2733                    Error::index(format!(
2734                        "positions builder missing position data for posting length {}",
2735                        length
2736                    ))
2737                })?;
2738                let CompressedPositionStorage::LegacyPerDoc(positions) = positions else {
2739                    return Err(Error::index(
2740                        "legacy positions builder received shared position stream".to_owned(),
2741                    ));
2742                };
2743                let docs_builder = position_lists.values();
2744                for doc_idx in 0..positions.len() {
2745                    let doc_positions = positions.value(doc_idx);
2746                    let compressed_positions = doc_positions.as_binary::<i64>();
2747                    for block_idx in 0..compressed_positions.len() {
2748                        docs_builder
2749                            .values()
2750                            .append_value(compressed_positions.value(block_idx));
2751                    }
2752                    docs_builder.append(true);
2753                }
2754                position_lists.append(true);
2755            }
2756        }
2757
2758        self.len += 1;
2759        Ok(())
2760    }
2761
2762    pub fn finish(&mut self) -> Result<RecordBatch> {
2763        let mut columns = vec![
2764            Arc::new(self.postings.finish()) as ArrayRef,
2765            Arc::new(self.max_scores.finish()) as ArrayRef,
2766            Arc::new(self.lengths.finish()) as ArrayRef,
2767        ];
2768        match &mut self.positions {
2769            BatchPositionsBuilder::None => {}
2770            BatchPositionsBuilder::Legacy(position_lists) => {
2771                columns.push(Arc::new(position_lists.finish()) as ArrayRef);
2772            }
2773            BatchPositionsBuilder::Shared {
2774                bytes,
2775                block_offsets,
2776            } => {
2777                columns.push(Arc::new(bytes.finish()) as ArrayRef);
2778                columns.push(Arc::new(block_offsets.finish()) as ArrayRef);
2779            }
2780        }
2781        self.len = 0;
2782        RecordBatch::try_new(self.schema.clone(), columns).map_err(Error::from)
2783    }
2784}
2785
2786impl PostingListBuilder {
2787    pub fn size(&self) -> u64 {
2788        self.memory_size_bytes as u64
2789    }
2790
2791    pub fn has_positions(&self) -> bool {
2792        self.with_positions
2793    }
2794
2795    pub fn new(with_position: bool) -> Self {
2796        Self::new_with_posting_tail_codec(
2797            with_position,
2798            current_fts_format_version().posting_tail_codec(),
2799        )
2800    }
2801
2802    pub fn new_with_posting_tail_codec(
2803        with_position: bool,
2804        posting_tail_codec: PostingTailCodec,
2805    ) -> Self {
2806        Self {
2807            with_positions: with_position,
2808            posting_tail_codec,
2809            encoded_blocks: None,
2810            encoded_position_blocks: None,
2811            tail_entries: Vec::new(),
2812            tail_positions: PositionBlockBuilder::default(),
2813            open_doc_id: None,
2814            open_doc_frequency: 0,
2815            open_doc_last_position: None,
2816            len: 0,
2817            memory_size_bytes: 0,
2818        }
2819    }
2820
2821    pub fn len(&self) -> usize {
2822        self.len as usize
2823    }
2824
2825    pub fn is_empty(&self) -> bool {
2826        self.len == 0
2827    }
2828
2829    pub fn iter(&self) -> std::vec::IntoIter<(u32, u32, Option<Vec<u32>>)> {
2830        self.collect_entries().into_iter()
2831    }
2832
2833    pub fn for_each_entry<E>(
2834        &self,
2835        mut visit: impl FnMut(u32, u32, Option<Vec<u32>>) -> std::result::Result<(), E>,
2836    ) -> std::result::Result<(), E> {
2837        let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
2838        let mut frequencies = Vec::with_capacity(BLOCK_SIZE);
2839        let mut decoded_positions = Vec::new();
2840        let mut position_block_index = 0usize;
2841
2842        if let Some(encoded_blocks) = self.encoded_blocks.as_deref() {
2843            for block in encoded_blocks.iter() {
2844                doc_ids.clear();
2845                frequencies.clear();
2846                super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies);
2847                decoded_positions.clear();
2848                if self.with_positions {
2849                    let position_blocks = self
2850                        .encoded_position_blocks
2851                        .as_deref()
2852                        .expect("positions must exist for posting list");
2853                    super::encoding::decode_position_stream_block(
2854                        position_blocks.block(position_block_index),
2855                        &frequencies,
2856                        PositionStreamCodec::PackedDelta,
2857                        &mut decoded_positions,
2858                    )
2859                    .expect("position stream decoding should succeed");
2860                    position_block_index += 1;
2861                }
2862                let mut offset = 0usize;
2863                for (doc_id, frequency) in doc_ids.iter().copied().zip(frequencies.iter().copied())
2864                {
2865                    let positions = self.with_positions.then(|| {
2866                        let end = offset + frequency as usize;
2867                        let doc_positions = decoded_positions[offset..end].to_vec();
2868                        offset = end;
2869                        doc_positions
2870                    });
2871                    visit(doc_id, frequency, positions)?;
2872                }
2873            }
2874        }
2875
2876        let mut decoded_tail_positions = Vec::new();
2877        if self.with_positions && !self.tail_entries.is_empty() {
2878            let tail_frequencies = self
2879                .tail_entries
2880                .iter()
2881                .map(|entry| entry.frequency)
2882                .collect::<Vec<_>>();
2883            self.tail_positions
2884                .decode_into(tail_frequencies.as_slice(), &mut decoded_tail_positions)
2885                .expect("tail position stream decoding should succeed");
2886        }
2887        let mut tail_offset = 0usize;
2888        for entry in &self.tail_entries {
2889            let positions = self.with_positions.then(|| {
2890                let end = tail_offset + entry.frequency as usize;
2891                let doc_positions = decoded_tail_positions[tail_offset..end].to_vec();
2892                tail_offset = end;
2893                doc_positions
2894            });
2895            visit(entry.doc_id, entry.frequency, positions)?;
2896        }
2897
2898        Ok(())
2899    }
2900
2901    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
2902        debug_assert!(
2903            self.open_doc_id.is_none(),
2904            "cannot add closed doc while a positions doc is still open"
2905        );
2906        let tail_entries_capacity_before = self.tail_entries.capacity();
2907        self.tail_entries
2908            .push(RawDocInfo::new(doc_id, term_positions.len()));
2909        let tail_entries_capacity_after = self.tail_entries.capacity();
2910        if tail_entries_capacity_after > tail_entries_capacity_before {
2911            self.add_memory_bytes(
2912                (tail_entries_capacity_after - tail_entries_capacity_before)
2913                    * std::mem::size_of::<RawDocInfo>(),
2914            );
2915        }
2916        if let PositionRecorder::Position(positions_in_doc) = term_positions {
2917            debug_assert!(self.with_positions);
2918            let old_size = self.tail_positions.size();
2919            self.tail_positions
2920                .append_doc_positions(positions_in_doc.as_slice())
2921                .expect("position stream encoding should succeed");
2922            self.adjust_tail_positions_size(old_size);
2923        }
2924        self.len += 1;
2925
2926        if self.tail_entries.len() == BLOCK_SIZE {
2927            self.flush_tail_block()
2928                .expect("posting list block compression should succeed");
2929        }
2930    }
2931
2932    pub fn add_occurrence(&mut self, doc_id: u32, position: u32) -> Result<bool> {
2933        if !self.with_positions {
2934            return Err(Error::index(
2935                "cannot append streamed positions to a posting list without positions".to_owned(),
2936            ));
2937        }
2938
2939        match self.open_doc_id {
2940            Some(open_doc_id) if open_doc_id == doc_id => {
2941                let old_size = self.tail_positions.size();
2942                self.tail_positions
2943                    .append_position(position, self.open_doc_last_position)?;
2944                self.adjust_tail_positions_size(old_size);
2945                self.open_doc_frequency += 1;
2946                self.open_doc_last_position = Some(position);
2947                Ok(false)
2948            }
2949            Some(open_doc_id) => Err(Error::index(format!(
2950                "posting list received doc {} before finishing open doc {}",
2951                doc_id, open_doc_id
2952            ))),
2953            None => {
2954                let old_size = self.tail_positions.size();
2955                self.tail_positions.append_position(position, None)?;
2956                self.adjust_tail_positions_size(old_size);
2957                self.open_doc_id = Some(doc_id);
2958                self.open_doc_frequency = 1;
2959                self.open_doc_last_position = Some(position);
2960                self.len += 1;
2961                Ok(true)
2962            }
2963        }
2964    }
2965
2966    pub fn finish_open_doc(&mut self, doc_id: u32) -> Result<()> {
2967        if !self.with_positions {
2968            return Ok(());
2969        }
2970        match self.open_doc_id {
2971            Some(open_doc_id) if open_doc_id == doc_id => {
2972                let tail_entries_capacity_before = self.tail_entries.capacity();
2973                self.tail_entries
2974                    .push(RawDocInfo::new(doc_id, self.open_doc_frequency));
2975                let tail_entries_capacity_after = self.tail_entries.capacity();
2976                if tail_entries_capacity_after > tail_entries_capacity_before {
2977                    self.add_memory_bytes(
2978                        (tail_entries_capacity_after - tail_entries_capacity_before)
2979                            * std::mem::size_of::<RawDocInfo>(),
2980                    );
2981                }
2982                self.open_doc_id = None;
2983                self.open_doc_frequency = 0;
2984                self.open_doc_last_position = None;
2985                if self.tail_entries.len() == BLOCK_SIZE {
2986                    self.flush_tail_block()?;
2987                }
2988                Ok(())
2989            }
2990            Some(open_doc_id) => Err(Error::index(format!(
2991                "attempted to finish doc {} while doc {} is still open",
2992                doc_id, open_doc_id
2993            ))),
2994            None => Ok(()),
2995        }
2996    }
2997
2998    fn collect_entries(&self) -> Vec<(u32, u32, Option<Vec<u32>>)> {
2999        let mut entries = Vec::with_capacity(self.len());
3000        self.for_each_entry(|doc_id, frequency, positions| {
3001            entries.push((doc_id, frequency, positions));
3002            Ok::<(), ()>(())
3003        })
3004        .expect("collecting posting list entries should not fail");
3005        entries
3006    }
3007
3008    fn encoded_blocks_mut(&mut self) -> &mut EncodedBlocks {
3009        if self.encoded_blocks.is_none() {
3010            self.encoded_blocks = Some(Box::default());
3011            self.add_memory_bytes(std::mem::size_of::<EncodedBlocks>());
3012        }
3013        self.encoded_blocks
3014            .as_deref_mut()
3015            .expect("encoded blocks must exist")
3016    }
3017
3018    fn encoded_position_blocks_mut(&mut self) -> &mut EncodedPositionBlocks {
3019        if self.encoded_position_blocks.is_none() {
3020            self.encoded_position_blocks = Some(Box::default());
3021            self.add_memory_bytes(std::mem::size_of::<EncodedPositionBlocks>());
3022        }
3023        self.encoded_position_blocks
3024            .as_deref_mut()
3025            .expect("encoded position blocks must exist")
3026    }
3027
3028    fn flush_tail_block(&mut self) -> Result<()> {
3029        if self.tail_entries.is_empty() {
3030            return Ok(());
3031        }
3032        debug_assert!(
3033            self.open_doc_id.is_none(),
3034            "cannot flush a posting block while a document is still open"
3035        );
3036        debug_assert_eq!(self.tail_entries.len(), BLOCK_SIZE);
3037        let mut doc_ids = [0u32; BLOCK_SIZE];
3038        let mut frequencies = [0u32; BLOCK_SIZE];
3039        for (index, entry) in self.tail_entries.iter().enumerate() {
3040            doc_ids[index] = entry.doc_id;
3041            frequencies[index] = entry.frequency;
3042        }
3043        let encoded_blocks_size_before = self
3044            .encoded_blocks
3045            .as_ref()
3046            .map(|encoded_blocks| encoded_blocks.size())
3047            .unwrap_or(0usize);
3048        self.encoded_blocks_mut()
3049            .push_full_block(&doc_ids, &frequencies)?;
3050        let encoded_blocks_size_after = self
3051            .encoded_blocks
3052            .as_ref()
3053            .map(|encoded_blocks| encoded_blocks.size())
3054            .unwrap_or(0usize);
3055        if encoded_blocks_size_after > encoded_blocks_size_before {
3056            self.add_memory_bytes(encoded_blocks_size_after - encoded_blocks_size_before);
3057        }
3058        if self.with_positions {
3059            let encoded_positions_size_before = self
3060                .encoded_position_blocks
3061                .as_ref()
3062                .map(|encoded| encoded.size())
3063                .unwrap_or(0usize);
3064            let released_tail_positions_bytes = self.tail_positions.size();
3065            let tail_position_block = std::mem::take(&mut self.tail_positions).finish();
3066            self.encoded_position_blocks_mut()
3067                .push_encoded_block(tail_position_block.as_slice());
3068            let encoded_positions_size_after = self
3069                .encoded_position_blocks
3070                .as_ref()
3071                .map(|encoded| encoded.size())
3072                .unwrap_or(0usize);
3073            if released_tail_positions_bytes > 0 {
3074                self.subtract_memory_bytes(released_tail_positions_bytes);
3075            }
3076            if encoded_positions_size_after > encoded_positions_size_before {
3077                self.add_memory_bytes(encoded_positions_size_after - encoded_positions_size_before);
3078            }
3079        }
3080        self.tail_entries.clear();
3081        Ok(())
3082    }
3083
3084    fn adjust_tail_positions_size(&mut self, old_size: usize) {
3085        let new_size = self.tail_positions.size();
3086        if new_size > old_size {
3087            self.add_memory_bytes(new_size - old_size);
3088        } else if old_size > new_size {
3089            self.subtract_memory_bytes(old_size - new_size);
3090        }
3091    }
3092
3093    fn add_memory_bytes(&mut self, bytes: usize) {
3094        self.memory_size_bytes = self
3095            .memory_size_bytes
3096            .checked_add(
3097                u32::try_from(bytes).expect("posting list memory size delta overflowed u32"),
3098            )
3099            .expect("posting list memory size overflowed u32");
3100    }
3101
3102    fn subtract_memory_bytes(&mut self, bytes: usize) {
3103        self.memory_size_bytes = self
3104            .memory_size_bytes
3105            .checked_sub(
3106                u32::try_from(bytes).expect("posting list memory size delta overflowed u32"),
3107            )
3108            .expect("posting list memory size underflowed u32");
3109    }
3110
3111    fn build_position_columns(
3112        positions: Option<CompressedPositionStorage>,
3113    ) -> Result<Vec<ArrayRef>> {
3114        let Some(positions) = positions else {
3115            return Ok(Vec::new());
3116        };
3117        match positions {
3118            CompressedPositionStorage::LegacyPerDoc(positions) => {
3119                Ok(vec![Arc::new(ListArray::try_new(
3120                    Arc::new(Field::new("item", positions.data_type().clone(), true)),
3121                    OffsetBuffer::new(ScalarBuffer::from(vec![0_i32, positions.len() as i32])),
3122                    Arc::new(positions) as ArrayRef,
3123                    None,
3124                )?) as ArrayRef])
3125            }
3126            CompressedPositionStorage::SharedStream(positions) => {
3127                let mut columns = Vec::with_capacity(2);
3128                columns.push(
3129                    Arc::new(LargeBinaryArray::from(vec![Some(positions.bytes())])) as ArrayRef,
3130                );
3131
3132                let mut offsets_builder = ListBuilder::new(UInt32Builder::new());
3133                for &offset in positions.block_offsets() {
3134                    offsets_builder.values().append_value(offset);
3135                }
3136                offsets_builder.append(true);
3137                columns.push(Arc::new(offsets_builder.finish()) as ArrayRef);
3138                Ok(columns)
3139            }
3140        }
3141    }
3142
3143    fn build_batch(
3144        self,
3145        compressed: LargeBinaryArray,
3146        max_score: f32,
3147        schema: SchemaRef,
3148        positions: Option<CompressedPositionStorage>,
3149    ) -> Result<RecordBatch> {
3150        let length = self.len();
3151        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
3152        let mut columns = vec![
3153            Arc::new(ListArray::try_new(
3154                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
3155                offsets,
3156                Arc::new(compressed),
3157                None,
3158            )?) as ArrayRef,
3159            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
3160            Arc::new(UInt32Array::from_iter_values(std::iter::once(
3161                length as u32,
3162            ))) as ArrayRef,
3163        ];
3164        columns.extend(Self::build_position_columns(positions)?);
3165
3166        let batch = RecordBatch::try_new(schema, columns)?;
3167        Ok(batch)
3168    }
3169
3170    fn build_legacy_positions(&self) -> Result<ListArray> {
3171        let mut positions_builder = ListBuilder::new(LargeBinaryBuilder::new());
3172        self.for_each_entry(|_doc_id, frequency, positions| {
3173            let positions = positions.ok_or_else(|| {
3174                Error::index(format!(
3175                    "legacy position writer missing positions for frequency {}",
3176                    frequency
3177                ))
3178            })?;
3179            let compressed = super::encoding::compress_positions(positions.as_slice())?;
3180            for block_idx in 0..compressed.len() {
3181                positions_builder
3182                    .values()
3183                    .append_value(compressed.value(block_idx));
3184            }
3185            positions_builder.append(true);
3186            Ok::<(), Error>(())
3187        })?;
3188        Ok(positions_builder.finish())
3189    }
3190
3191    pub(super) fn append_to_batch_with_docs(
3192        self,
3193        docs: &DocSet,
3194        batch_builder: &mut PostingListBatchBuilder,
3195        format_version: InvertedListFormatVersion,
3196    ) -> Result<()> {
3197        let legacy_positions =
3198            if self.with_positions && !format_version.uses_shared_position_stream() {
3199                Some(self.build_legacy_positions()?)
3200            } else {
3201                None
3202            };
3203        let Self {
3204            with_positions,
3205            posting_tail_codec,
3206            encoded_blocks,
3207            encoded_position_blocks,
3208            tail_entries,
3209            tail_positions,
3210            open_doc_id,
3211            open_doc_frequency,
3212            open_doc_last_position,
3213            len,
3214            ..
3215        } = self;
3216        debug_assert!(open_doc_id.is_none());
3217        debug_assert_eq!(open_doc_frequency, 0);
3218        debug_assert!(open_doc_last_position.is_none());
3219        let parts = PostingListParts {
3220            with_positions,
3221            posting_tail_codec,
3222            length: len as usize,
3223            encoded_blocks: encoded_blocks
3224                .map(|encoded_blocks| *encoded_blocks)
3225                .unwrap_or_default(),
3226            encoded_position_blocks: encoded_position_blocks
3227                .map(|encoded_positions| *encoded_positions)
3228                .unwrap_or_default(),
3229            tail_entries: tail_entries.as_slice(),
3230            tail_position_block: with_positions.then(|| tail_positions.finish()),
3231        };
3232        let (compressed, shared_positions, max_score) =
3233            Self::build_compressed_with_scores_from_parts(parts, docs)?;
3234        let positions = match legacy_positions {
3235            Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3236            None => shared_positions.map(CompressedPositionStorage::SharedStream),
3237        };
3238        batch_builder.append(compressed, max_score, len, positions.as_ref())
3239    }
3240
3241    fn extend_tail_components(
3242        tail_entries: &[RawDocInfo],
3243        doc_ids: &mut Vec<u32>,
3244        frequencies: &mut Vec<u32>,
3245    ) {
3246        doc_ids.clear();
3247        frequencies.clear();
3248        doc_ids.extend(tail_entries.iter().map(|entry| entry.doc_id));
3249        frequencies.extend(tail_entries.iter().map(|entry| entry.frequency));
3250    }
3251
3252    fn build_compressed_with_scores_from_parts(
3253        parts: PostingListParts<'_>,
3254        docs: &DocSet,
3255    ) -> Result<(LargeBinaryArray, Option<SharedPositionStream>, f32)> {
3256        let PostingListParts {
3257            with_positions,
3258            posting_tail_codec,
3259            length,
3260            mut encoded_blocks,
3261            mut encoded_position_blocks,
3262            tail_entries,
3263            tail_position_block,
3264        } = parts;
3265        let avgdl = docs.average_length();
3266        let idf_scale = idf(length, docs.len()) * (K1 + 1.0);
3267        let mut max_score = f32::MIN;
3268        let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
3269        let mut frequencies = Vec::with_capacity(BLOCK_SIZE);
3270
3271        for index in 0..encoded_blocks.len() {
3272            let block = encoded_blocks.block(index);
3273            doc_ids.clear();
3274            frequencies.clear();
3275            super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies);
3276            let block_score = compute_block_score(
3277                docs,
3278                avgdl,
3279                idf_scale,
3280                doc_ids.iter().copied(),
3281                frequencies.iter().copied(),
3282            );
3283            max_score = max_score.max(block_score);
3284            encoded_blocks.set_block_score(index, block_score);
3285        }
3286
3287        if !tail_entries.is_empty() {
3288            Self::extend_tail_components(tail_entries, &mut doc_ids, &mut frequencies);
3289            let block_score = compute_block_score(
3290                docs,
3291                avgdl,
3292                idf_scale,
3293                doc_ids.iter().copied(),
3294                frequencies.iter().copied(),
3295            );
3296            max_score = max_score.max(block_score);
3297            encoded_blocks.append_remainder_block_with_codec(
3298                doc_ids.as_slice(),
3299                frequencies.as_slice(),
3300                posting_tail_codec,
3301            )?;
3302            encoded_blocks.set_block_score(encoded_blocks.len() - 1, block_score);
3303            if with_positions {
3304                encoded_position_blocks.push_encoded_block(
3305                    tail_position_block
3306                        .as_deref()
3307                        .expect("tail position block must exist for postings with positions"),
3308                );
3309            }
3310        }
3311
3312        Ok((
3313            encoded_blocks.into_array(),
3314            with_positions.then(|| encoded_position_blocks.into_stream()),
3315            max_score,
3316        ))
3317    }
3318
3319    fn build_compressed_with_block_scores_from_parts(
3320        with_positions: bool,
3321        posting_tail_codec: PostingTailCodec,
3322        mut encoded_blocks: EncodedBlocks,
3323        mut encoded_position_blocks: EncodedPositionBlocks,
3324        tail_entries: &[RawDocInfo],
3325        tail_position_block: Option<Vec<u8>>,
3326        mut block_max_scores: impl Iterator<Item = f32>,
3327    ) -> Result<(LargeBinaryArray, Option<SharedPositionStream>, f32)> {
3328        let mut max_score = f32::MIN;
3329        let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
3330        let mut frequencies = Vec::with_capacity(BLOCK_SIZE);
3331
3332        for index in 0..encoded_blocks.len() {
3333            let block_score = block_max_scores
3334                .next()
3335                .ok_or_else(|| Error::index("missing block max score".to_owned()))?;
3336            max_score = max_score.max(block_score);
3337            encoded_blocks.set_block_score(index, block_score);
3338        }
3339
3340        if !tail_entries.is_empty() {
3341            let block_score = block_max_scores
3342                .next()
3343                .ok_or_else(|| Error::index("missing tail block max score".to_owned()))?;
3344            max_score = max_score.max(block_score);
3345            Self::extend_tail_components(tail_entries, &mut doc_ids, &mut frequencies);
3346            encoded_blocks.append_remainder_block_with_codec(
3347                doc_ids.as_slice(),
3348                frequencies.as_slice(),
3349                posting_tail_codec,
3350            )?;
3351            encoded_blocks.set_block_score(encoded_blocks.len() - 1, block_score);
3352            if with_positions {
3353                encoded_position_blocks.push_encoded_block(
3354                    tail_position_block
3355                        .as_deref()
3356                        .expect("tail position block must exist for postings with positions"),
3357                );
3358            }
3359        }
3360
3361        Ok((
3362            encoded_blocks.into_array(),
3363            with_positions.then(|| encoded_position_blocks.into_stream()),
3364            max_score,
3365        ))
3366    }
3367
3368    pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
3369        let format_version = if self.posting_tail_codec == PostingTailCodec::Fixed32 {
3370            InvertedListFormatVersion::V1
3371        } else {
3372            InvertedListFormatVersion::V2
3373        };
3374        let schema = inverted_list_schema_for_version(self.has_positions(), format_version);
3375        let legacy_positions =
3376            if self.with_positions && !format_version.uses_shared_position_stream() {
3377                Some(self.build_legacy_positions()?)
3378            } else {
3379                None
3380            };
3381        let Self {
3382            with_positions,
3383            posting_tail_codec,
3384            encoded_blocks,
3385            encoded_position_blocks,
3386            tail_entries,
3387            tail_positions,
3388            open_doc_id,
3389            open_doc_frequency,
3390            open_doc_last_position,
3391            len,
3392            ..
3393        } = self;
3394        debug_assert!(open_doc_id.is_none());
3395        debug_assert_eq!(open_doc_frequency, 0);
3396        debug_assert!(open_doc_last_position.is_none());
3397        let (compressed, shared_positions, max_score) =
3398            Self::build_compressed_with_block_scores_from_parts(
3399                with_positions,
3400                posting_tail_codec,
3401                encoded_blocks
3402                    .map(|encoded_blocks| *encoded_blocks)
3403                    .unwrap_or_default(),
3404                encoded_position_blocks
3405                    .map(|encoded_positions| *encoded_positions)
3406                    .unwrap_or_default(),
3407                tail_entries.as_slice(),
3408                with_positions.then(|| tail_positions.finish()),
3409                block_max_scores.into_iter(),
3410            )?;
3411        let builder = Self {
3412            with_positions,
3413            posting_tail_codec,
3414            encoded_blocks: None,
3415            encoded_position_blocks: None,
3416            tail_entries: Vec::new(),
3417            tail_positions: PositionBlockBuilder::default(),
3418            open_doc_id: None,
3419            open_doc_frequency: 0,
3420            open_doc_last_position: None,
3421            memory_size_bytes: 0,
3422            len,
3423        };
3424        let positions = match legacy_positions {
3425            Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3426            None => shared_positions.map(CompressedPositionStorage::SharedStream),
3427        };
3428        builder.build_batch(compressed, max_score, schema, positions)
3429    }
3430
3431    pub fn to_batch_with_docs(self, docs: &DocSet, schema: SchemaRef) -> Result<RecordBatch> {
3432        let format_version = if schema.column_with_name(POSITION_COL).is_some()
3433            && schema.column_with_name(COMPRESSED_POSITION_COL).is_none()
3434        {
3435            InvertedListFormatVersion::V1
3436        } else {
3437            InvertedListFormatVersion::V2
3438        };
3439        let legacy_positions =
3440            if self.with_positions && !format_version.uses_shared_position_stream() {
3441                Some(self.build_legacy_positions()?)
3442            } else {
3443                None
3444            };
3445        let Self {
3446            with_positions,
3447            posting_tail_codec,
3448            encoded_blocks,
3449            encoded_position_blocks,
3450            tail_entries,
3451            tail_positions,
3452            open_doc_id,
3453            open_doc_frequency,
3454            open_doc_last_position,
3455            len,
3456            ..
3457        } = self;
3458        debug_assert!(open_doc_id.is_none());
3459        debug_assert_eq!(open_doc_frequency, 0);
3460        debug_assert!(open_doc_last_position.is_none());
3461        let parts = PostingListParts {
3462            with_positions,
3463            posting_tail_codec,
3464            length: len as usize,
3465            encoded_blocks: encoded_blocks
3466                .map(|encoded_blocks| *encoded_blocks)
3467                .unwrap_or_default(),
3468            encoded_position_blocks: encoded_position_blocks
3469                .map(|encoded_positions| *encoded_positions)
3470                .unwrap_or_default(),
3471            tail_entries: tail_entries.as_slice(),
3472            tail_position_block: with_positions.then(|| tail_positions.finish()),
3473        };
3474        let (compressed, shared_positions, max_score) =
3475            Self::build_compressed_with_scores_from_parts(parts, docs)?;
3476        let builder = Self {
3477            with_positions,
3478            posting_tail_codec,
3479            encoded_blocks: None,
3480            encoded_position_blocks: None,
3481            tail_entries: Vec::new(),
3482            tail_positions: PositionBlockBuilder::default(),
3483            open_doc_id: None,
3484            open_doc_frequency: 0,
3485            open_doc_last_position: None,
3486            memory_size_bytes: 0,
3487            len,
3488        };
3489        let positions = match legacy_positions {
3490            Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3491            None => shared_positions.map(CompressedPositionStorage::SharedStream),
3492        };
3493        builder.build_batch(compressed, max_score, schema, positions)
3494    }
3495
3496    pub fn remap(&mut self, removed: &[u32]) {
3497        let mut cursor = 0;
3498        let mut new_builder =
3499            Self::new_with_posting_tail_codec(self.has_positions(), self.posting_tail_codec);
3500        for (doc_id, freq, positions) in self.iter() {
3501            while cursor < removed.len() && removed[cursor] < doc_id {
3502                cursor += 1;
3503            }
3504            if cursor < removed.len() && removed[cursor] == doc_id {
3505                continue;
3506            }
3507            let positions = match positions {
3508                Some(positions) => PositionRecorder::Position(positions.into()),
3509                None => PositionRecorder::Count(freq),
3510            };
3511            new_builder.add(doc_id - cursor as u32, positions);
3512        }
3513
3514        *self = new_builder;
3515    }
3516}
3517
3518fn compute_block_score(
3519    docs: &DocSet,
3520    avgdl: f32,
3521    idf_scale: f32,
3522    doc_ids: impl Iterator<Item = u32>,
3523    frequencies: impl Iterator<Item = u32>,
3524) -> f32 {
3525    let mut block_max_score = f32::MIN;
3526    for (doc_id, freq) in doc_ids.zip(frequencies) {
3527        let doc_norm = K1 * (1.0 - B + B * docs.num_tokens(doc_id) as f32 / avgdl);
3528        let freq = freq as f32;
3529        let score = freq / (freq + doc_norm);
3530        block_max_score = block_max_score.max(score);
3531    }
3532    block_max_score * idf_scale
3533}
3534
3535#[derive(Debug, Clone, DeepSizeOf, Copy)]
3536pub enum DocInfo {
3537    Located(LocatedDocInfo),
3538    Raw(RawDocInfo),
3539}
3540
3541impl DocInfo {
3542    pub fn doc_id(&self) -> u64 {
3543        match self {
3544            Self::Raw(info) => info.doc_id as u64,
3545            Self::Located(info) => info.row_id,
3546        }
3547    }
3548
3549    pub fn frequency(&self) -> u32 {
3550        match self {
3551            Self::Raw(info) => info.frequency,
3552            Self::Located(info) => info.frequency as u32,
3553        }
3554    }
3555}
3556
3557impl Eq for DocInfo {}
3558
3559impl PartialEq for DocInfo {
3560    fn eq(&self, other: &Self) -> bool {
3561        self.doc_id() == other.doc_id()
3562    }
3563}
3564
3565impl PartialOrd for DocInfo {
3566    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3567        Some(self.cmp(other))
3568    }
3569}
3570
3571impl Ord for DocInfo {
3572    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3573        self.doc_id().cmp(&other.doc_id())
3574    }
3575}
3576
3577#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
3578pub struct LocatedDocInfo {
3579    pub row_id: u64,
3580    pub frequency: f32,
3581}
3582
3583impl LocatedDocInfo {
3584    pub fn new(row_id: u64, frequency: f32) -> Self {
3585        Self { row_id, frequency }
3586    }
3587}
3588
3589impl Eq for LocatedDocInfo {}
3590
3591impl PartialEq for LocatedDocInfo {
3592    fn eq(&self, other: &Self) -> bool {
3593        self.row_id == other.row_id
3594    }
3595}
3596
3597impl PartialOrd for LocatedDocInfo {
3598    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3599        Some(self.cmp(other))
3600    }
3601}
3602
3603impl Ord for LocatedDocInfo {
3604    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3605        self.row_id.cmp(&other.row_id)
3606    }
3607}
3608
3609#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
3610pub struct RawDocInfo {
3611    pub doc_id: u32,
3612    pub frequency: u32,
3613}
3614
3615impl RawDocInfo {
3616    pub fn new(doc_id: u32, frequency: u32) -> Self {
3617        Self { doc_id, frequency }
3618    }
3619}
3620
3621impl Eq for RawDocInfo {}
3622
3623impl PartialEq for RawDocInfo {
3624    fn eq(&self, other: &Self) -> bool {
3625        self.doc_id == other.doc_id
3626    }
3627}
3628
3629impl PartialOrd for RawDocInfo {
3630    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3631        Some(self.cmp(other))
3632    }
3633}
3634
3635impl Ord for RawDocInfo {
3636    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3637        self.doc_id.cmp(&other.doc_id)
3638    }
3639}
3640
/// DocSet is a mapping from row ids to the number of tokens in the document.
/// It's used to sort the documents by the bm25 score.
///
/// `row_ids` and `num_tokens` are parallel vectors with one entry per document.
/// Outside the legacy format a document's doc id is its position in these vectors
/// (see `row_id` and `num_tokens`); in the legacy format the row id itself acts as
/// the doc id and `row_ids` is kept sorted so it can be binary-searched.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct DocSet {
    // row id of each document
    row_ids: Vec<u64>,
    // token count of each document, parallel to `row_ids`
    num_tokens: Vec<u32>,
    // (row_id, doc_id) pairs sorted by row_id
    // left empty for the legacy format and for sets built through `append`
    inv: Vec<(u64, u32)>,

    // sum of `num_tokens`; maintained by `append` and computed in `load`
    total_tokens: u64,
}
3652
3653impl DocSet {
    /// Number of documents in the set.
    #[inline]
    pub fn len(&self) -> usize {
        self.row_ids.len()
    }
3658
3659    pub fn is_empty(&self) -> bool {
3660        self.len() == 0
3661    }
3662
3663    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
3664        self.row_ids.iter().zip(self.num_tokens.iter())
3665    }
3666
    /// Returns the row id stored at position `doc_id`.
    ///
    /// # Panics
    /// Panics if `doc_id` is out of bounds.
    pub fn row_id(&self, doc_id: u32) -> u64 {
        self.row_ids[doc_id as usize]
    }
3670
3671    pub fn doc_id(&self, row_id: u64) -> Option<u64> {
3672        if self.inv.is_empty() {
3673            // in legacy format, the row id is doc id
3674            match self.row_ids.binary_search(&row_id) {
3675                Ok(_) => Some(row_id),
3676                Err(_) => None,
3677            }
3678        } else {
3679            match self.inv.binary_search_by_key(&row_id, |x| x.0) {
3680                Ok(idx) => Some(self.inv[idx].1 as u64),
3681                Err(_) => None,
3682            }
3683        }
3684    }
    /// Returns the stored total token count across all documents.
    pub fn total_tokens_num(&self) -> u64 {
        self.total_tokens
    }
3688
3689    #[inline]
3690    pub fn average_length(&self) -> f32 {
3691        self.total_tokens as f32 / self.len() as f32
3692    }
3693
3694    pub fn calculate_block_max_scores<'a>(
3695        &self,
3696        doc_ids: impl Iterator<Item = &'a u32>,
3697        freqs: impl Iterator<Item = &'a u32>,
3698    ) -> Vec<f32> {
3699        let avgdl = self.average_length();
3700        let length = doc_ids.size_hint().0;
3701        let num_blocks = length.div_ceil(BLOCK_SIZE);
3702        let mut block_max_scores = Vec::with_capacity(num_blocks);
3703        let idf_scale = idf(length, self.len()) * (K1 + 1.0);
3704        let mut max_score = f32::MIN;
3705        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
3706            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
3707            let freq = *freq as f32;
3708            let score = freq / (freq + doc_norm);
3709            if score > max_score {
3710                max_score = score;
3711            }
3712            if (i + 1) % BLOCK_SIZE == 0 {
3713                max_score *= idf_scale;
3714                block_max_scores.push(max_score);
3715                max_score = f32::MIN;
3716            }
3717        }
3718        if !length.is_multiple_of(BLOCK_SIZE) {
3719            max_score *= idf_scale;
3720            block_max_scores.push(max_score);
3721        }
3722        block_max_scores
3723    }
3724
3725    pub fn to_batch(&self) -> Result<RecordBatch> {
3726        let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
3727        let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
3728
3729        let schema = arrow_schema::Schema::new(vec![
3730            arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
3731            arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
3732        ]);
3733
3734        let batch = RecordBatch::try_new(
3735            Arc::new(schema),
3736            vec![
3737                Arc::new(row_id_col) as ArrayRef,
3738                Arc::new(num_tokens_col) as ArrayRef,
3739            ],
3740        )?;
3741        Ok(batch)
3742    }
3743
3744    pub async fn load(
3745        reader: Arc<dyn IndexReader>,
3746        is_legacy: bool,
3747        frag_reuse_index: Option<Arc<FragReuseIndex>>,
3748    ) -> Result<Self> {
3749        let batch = reader.read_range(0..reader.num_rows(), None).await?;
3750        let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
3751        let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
3752
3753        // for legacy format, the row id is doc id; sorting keeps binary search viable
3754        if is_legacy {
3755            let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
3756                .values()
3757                .iter()
3758                .filter_map(|id| {
3759                    if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
3760                        frag_reuse_index_ref.remap_row_id(*id)
3761                    } else {
3762                        Some(*id)
3763                    }
3764                })
3765                .zip(num_tokens_col.values().iter())
3766                .sorted_unstable_by_key(|x| x.0)
3767                .unzip();
3768
3769            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3770            return Ok(Self {
3771                row_ids,
3772                num_tokens,
3773                inv: Vec::new(),
3774                total_tokens,
3775            });
3776        }
3777
3778        // if frag reuse happened, we'll need to remap the row_ids. And after row_ids been
3779        // remapped, we'll need resort to make sure binary_search works.
3780        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
3781            let mut row_ids = Vec::with_capacity(row_id_col.len());
3782            let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
3783            for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
3784                if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
3785                    row_ids.push(new_row_id);
3786                    num_tokens.push(*num_token);
3787                }
3788            }
3789
3790            let mut inv: Vec<(u64, u32)> = row_ids
3791                .iter()
3792                .enumerate()
3793                .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
3794                .collect();
3795            inv.sort_unstable_by_key(|entry| entry.0);
3796
3797            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3798            return Ok(Self {
3799                row_ids,
3800                num_tokens,
3801                inv,
3802                total_tokens,
3803            });
3804        }
3805
3806        let row_ids = row_id_col.values().to_vec();
3807        let num_tokens = num_tokens_col.values().to_vec();
3808        let mut inv: Vec<(u64, u32)> = row_ids
3809            .iter()
3810            .enumerate()
3811            .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
3812            .collect();
3813        if !row_ids.is_sorted() {
3814            inv.sort_unstable_by_key(|entry| entry.0);
3815        }
3816        let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3817        Ok(Self {
3818            row_ids,
3819            num_tokens,
3820            inv,
3821            total_tokens,
3822        })
3823    }
3824
3825    // remap the row ids to the new row ids
3826    // returns the removed doc ids
3827    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
3828        let mut removed = Vec::new();
3829        let len = self.len();
3830        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
3831        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
3832        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
3833            match mapping.get(&row_id) {
3834                Some(Some(new_row_id)) => {
3835                    self.row_ids.push(*new_row_id);
3836                    self.num_tokens.push(num_token);
3837                }
3838                Some(None) => {
3839                    removed.push(doc_id as u32);
3840                }
3841                None => {
3842                    self.row_ids.push(row_id);
3843                    self.num_tokens.push(num_token);
3844                }
3845            }
3846        }
3847        removed
3848    }
3849
    /// Returns the token count stored at position `doc_id`.
    ///
    /// # Panics
    /// Panics if `doc_id` is out of bounds.
    #[inline]
    pub fn num_tokens(&self, doc_id: u32) -> u32 {
        self.num_tokens[doc_id as usize]
    }
3854
3855    // this can be used only if it's a legacy format,
3856    // which store the sorted row ids so that we can use binary search
3857    #[inline]
3858    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
3859        self.row_ids
3860            .binary_search(&row_id)
3861            .map(|idx| self.num_tokens[idx])
3862            .unwrap_or(0)
3863    }
3864
3865    // append a document to the doc set
3866    // returns the doc_id (the number of documents before appending)
3867    pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
3868        self.row_ids.push(row_id);
3869        self.num_tokens.push(num_tokens);
3870        self.total_tokens += num_tokens as u64;
3871        self.row_ids.len() as u32 - 1
3872    }
3873
3874    pub(crate) fn memory_size(&self) -> usize {
3875        self.row_ids.capacity() * std::mem::size_of::<u64>()
3876            + self.num_tokens.capacity() * std::mem::size_of::<u32>()
3877            + self.inv.capacity() * std::mem::size_of::<(u64, u32)>()
3878    }
3879}
3880
3881pub fn flat_full_text_search(
3882    batches: &[&RecordBatch],
3883    doc_col: &str,
3884    query: &str,
3885    tokenizer: Option<Box<dyn LanceTokenizer>>,
3886) -> Result<Vec<u64>> {
3887    if batches.is_empty() {
3888        return Ok(vec![]);
3889    }
3890
3891    if is_phrase_query(query) {
3892        return Err(Error::invalid_input(
3893            "phrase query is not supported for flat full text search, try using FTS index",
3894        ));
3895    }
3896
3897    match batches[0][doc_col].data_type() {
3898        DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
3899        DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
3900        data_type => Err(Error::invalid_input(format!(
3901            "unsupported data type {} for inverted index",
3902            data_type
3903        ))),
3904    }
3905}
3906
3907fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
3908    batches: &[&RecordBatch],
3909    doc_col: &str,
3910    query: &str,
3911    tokenizer: Option<Box<dyn LanceTokenizer>>,
3912) -> Result<Vec<u64>> {
3913    let mut results = Vec::new();
3914    let mut tokenizer =
3915        tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
3916    let query_tokens = collect_query_tokens(query, &mut tokenizer);
3917
3918    for batch in batches {
3919        let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
3920        let doc_array = batch[doc_col].as_string::<Offset>();
3921        for i in 0..row_id_array.len() {
3922            let doc = doc_array.value(i);
3923            if has_query_token(doc, &mut tokenizer, &query_tokens) {
3924                results.push(row_id_array.value(i));
3925                // What is this assertion for?  Why would doc contain query?  Don't we reach
3926                // here only if they share at least one token?  Why is it not debug_assert?
3927                assert!(doc.contains(query));
3928            }
3929        }
3930    }
3931
3932    Ok(results)
3933}
3934
// Column positions in the counted batch produced by `tokenize_and_count`
// (row id, total tokens per document, per-query-token counts).
const FLAT_ROW_ID_COL_IDX: usize = 0;
const FLAT_ALL_TOKENS_COL_IDX: usize = 1;
const FLAT_QUERY_TOKEN_COUNTS_COL_IDX: usize = 2;

/// If we accumulate this many bytes we warn the user they probably want to use an FTS index instead.
const BYTES_ACCUMULATED_WARNING_THRESHOLD: u64 = 1024 * 1024 * 1024; // 1GB
3941
3942/// Consumes a stream of record batches and produces token counts
3943///
3944/// The resulting batch will have three columns:
3945/// - row_id: the row id of the document
3946/// - all_tokens: the total number of tokens in the document
3947/// - query_token_counts: a fixed size list of the count of each query token in the document
3948///
3949/// This is an unbounded accumulation, however, for most queries, the per-row
3950/// growth will be fairly small.  As a result we can process millions of tokens
3951/// with fairly modest memory usage.
3952///
3953/// However, it is unwise to do a flat search across billions of rows.  An FTS
3954/// index should be created instead.
3955async fn tokenize_and_count(
3956    input: impl Stream<Item = DataFusionResult<RecordBatch>> + Send,
3957    tokenizer: Box<dyn LanceTokenizer>,
3958    query_tokens: Arc<Tokens>,
3959    doc_col_idx: usize,
3960) -> DataFusionResult<RecordBatch> {
3961    let output_schema = Arc::new(Schema::new(vec![
3962        ROW_ID_FIELD.clone(),
3963        Field::new("all_tokens", DataType::UInt64, false),
3964        Field::new(
3965            "query_token_counts",
3966            DataType::FixedSizeList(
3967                Arc::new(Field::new("item", DataType::UInt64, true)),
3968                query_tokens.len() as i32,
3969            ),
3970            false,
3971        ),
3972    ]));
3973    let output_schema_clone = output_schema.clone();
3974    let bytes_accumulated = Arc::new(AtomicU64::new(0));
3975    let bytes_warning_emitted = Arc::new(AtomicBool::new(false));
3976
3977    let batches = input
3978        .map(move |batch| {
3979            let mut tokenizer = tokenizer.box_clone();
3980            let output_schema = output_schema.clone();
3981            let query_tokens = query_tokens.clone();
3982            let bytes_accumulated = bytes_accumulated.clone();
3983            let bytes_warning_emitted = bytes_warning_emitted.clone();
3984            spawn_cpu(move || {
3985                let batch = batch?;
3986                let mut all_token_counts = UInt64Builder::with_capacity(batch.num_rows());
3987                let mut query_token_counts = FixedSizeListBuilder::with_capacity(
3988                    UInt64Builder::with_capacity(batch.num_rows() * query_tokens.len()),
3989                    query_tokens.len() as i32,
3990                    batch.num_rows(),
3991                );
3992                let mut temp_query_token_counts = Vec::with_capacity(query_tokens.len());
3993                let doc_iter = iter_str_array(batch.column(doc_col_idx));
3994                for doc in doc_iter {
3995                    let Some(doc) = doc else {
3996                        all_token_counts.append_value(0);
3997                        query_token_counts
3998                            .values()
3999                            .append_value_n(0, query_tokens.len());
4000                        query_token_counts.append(true);
4001                        continue;
4002                    };
4003
4004                    temp_query_token_counts.clear();
4005                    temp_query_token_counts.extend(std::iter::repeat_n(0, query_tokens.len()));
4006
4007                    let mut stream = tokenizer.token_stream_for_doc(doc);
4008                    let mut all_tokens = 0;
4009                    while let Some(token) = stream.next() {
4010                        all_tokens += 1;
4011                        if let Some(token_index) = query_tokens.token_index(&token.text) {
4012                            temp_query_token_counts[token_index] += 1;
4013                        }
4014                    }
4015                    all_token_counts.append_value(all_tokens);
4016                    for count in temp_query_token_counts.iter().copied() {
4017                        query_token_counts.values().append_value(count);
4018                    }
4019                    query_token_counts.append(true);
4020                }
4021                let row_ids = batch[ROW_ID].clone();
4022                let all_token_counts = all_token_counts.finish();
4023                let query_token_counts = query_token_counts.finish();
4024                let result_batch = RecordBatch::try_new(
4025
4026                    output_schema,
4027                    vec![
4028                        row_ids,
4029                        Arc::new(all_token_counts) as ArrayRef,
4030                        Arc::new(query_token_counts) as ArrayRef,
4031                    ],
4032                )?;
4033                let bytes_accumulated = bytes_accumulated.fetch_add(result_batch.get_array_memory_size() as u64, Ordering::Relaxed);
4034                if bytes_accumulated > BYTES_ACCUMULATED_WARNING_THRESHOLD && !bytes_warning_emitted.swap(true, Ordering::Relaxed) {
4035                    tracing::warn!("Flat full text search is accumulating a large number of bytes.  Consider using an FTS index instead.");
4036                }
4037
4038                DataFusionResult::Ok(result_batch)
4039            })
4040        })
4041        .buffered(get_num_compute_intensive_cpus())
4042        .try_collect::<Vec<_>>()
4043        .await?;
4044
4045    Ok(arrow::compute::concat_batches(
4046        &output_schema_clone,
4047        &batches,
4048    )?)
4049}
4050
4051/// Initialize the BM25 scorer
4052///
4053/// In order to calculate BM25 scores we need to know token counts for the entire corpus.  We extract these from the
4054/// counted input of the flat search combined with any counts recorded for the indexed portion.
4055fn initialize_scorer(
4056    base_scorer: Option<&MemBM25Scorer>,
4057    query_tokens: &Tokens,
4058    counted_input: &RecordBatch,
4059) -> MemBM25Scorer {
4060    let mut total_tokens = 0;
4061    let mut num_docs = 0;
4062    let mut all_token_counts = vec![0; query_tokens.len()];
4063
4064    if let Some(base_scorer) = base_scorer {
4065        total_tokens += base_scorer.total_tokens;
4066        num_docs += base_scorer.num_docs;
4067        for (token_index, token) in query_tokens.into_iter().enumerate() {
4068            all_token_counts[token_index] = base_scorer.num_docs_containing_token(token) as u64;
4069        }
4070    }
4071
4072    num_docs += counted_input.num_rows();
4073    total_tokens += arrow::compute::sum(
4074        counted_input
4075            .column(FLAT_ALL_TOKENS_COL_IDX)
4076            .as_primitive::<UInt64Type>(),
4077    )
4078    .unwrap_or_default();
4079
4080    let mut input_token_counters = counted_input
4081        .column(FLAT_QUERY_TOKEN_COUNTS_COL_IDX)
4082        .as_fixed_size_list()
4083        .values()
4084        .as_primitive::<UInt64Type>()
4085        .values()
4086        .iter()
4087        .copied();
4088
4089    for _ in 0..counted_input.num_rows() {
4090        for token_count in all_token_counts.iter_mut() {
4091            *token_count += input_token_counters.next().unwrap_or_default();
4092        }
4093    }
4094
4095    let token_counts_map = all_token_counts
4096        .into_iter()
4097        .enumerate()
4098        .map(|(token_index, count)| {
4099            (
4100                query_tokens.get_token(token_index).to_string(),
4101                count as usize,
4102            )
4103        })
4104        .collect::<HashMap<String, usize>>();
4105    MemBM25Scorer::new(total_tokens, num_docs, token_counts_map)
4106}
4107
4108fn flat_bm25_score(
4109    query_tokens: &Tokens,
4110    counted_input: &RecordBatch,
4111    scorer: &MemBM25Scorer,
4112) -> Result<RecordBatch> {
4113    let mut row_ids_builder = UInt64Builder::with_capacity(counted_input.num_rows());
4114    let mut scores_builder = Float32Builder::with_capacity(counted_input.num_rows());
4115
4116    let mut row_ids_iter = counted_input
4117        .column(FLAT_ROW_ID_COL_IDX)
4118        .as_primitive::<UInt64Type>()
4119        .values()
4120        .iter()
4121        .copied();
4122    let mut all_token_counts_iter = counted_input
4123        .column(FLAT_ALL_TOKENS_COL_IDX)
4124        .as_primitive::<UInt64Type>()
4125        .values()
4126        .iter()
4127        .copied();
4128    let mut query_token_counts_iter = counted_input
4129        .column(FLAT_QUERY_TOKEN_COUNTS_COL_IDX)
4130        .as_fixed_size_list()
4131        .values()
4132        .as_primitive::<UInt64Type>()
4133        .values()
4134        .iter()
4135        .copied();
4136    for _ in 0..counted_input.num_rows() {
4137        let num_tokens_in_doc = all_token_counts_iter.next().expect_ok()?;
4138        let row_id = row_ids_iter.next().expect_ok()?;
4139        if num_tokens_in_doc == 0 {
4140            for _ in query_tokens {
4141                query_token_counts_iter.next().expect_ok()?;
4142            }
4143            continue;
4144        }
4145        let doc_norm = K1 * (1.0 - B + B * num_tokens_in_doc as f32 / scorer.avg_doc_length());
4146        let mut score = 0.0;
4147        for token in query_tokens {
4148            let freq = query_token_counts_iter.next().expect_ok()? as f32;
4149            let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
4150            score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
4151        }
4152        if score > 0.0 {
4153            row_ids_builder.append_value(row_id);
4154            scores_builder.append_value(score);
4155        }
4156    }
4157
4158    let row_ids = row_ids_builder.finish();
4159    let scores = scores_builder.finish();
4160    let batch = RecordBatch::try_new(
4161        FTS_SCHEMA.clone(),
4162        vec![Arc::new(row_ids) as ArrayRef, Arc::new(scores) as ArrayRef],
4163    )?;
4164    Ok(batch)
4165}
4166
4167pub async fn flat_bm25_search_stream(
4168    input: SendableRecordBatchStream,
4169    doc_col: String,
4170    query: String,
4171    tokenizer: Box<dyn LanceTokenizer>,
4172    base_scorer: Option<MemBM25Scorer>,
4173    target_batch_size: usize,
4174) -> DataFusionResult<SendableRecordBatchStream> {
4175    let mut tokenizer = tokenizer;
4176    let query_tokens = Arc::new(collect_query_tokens(&query, &mut tokenizer));
4177
4178    let input_schema = input.schema();
4179    let doc_col_idx = input_schema.index_of(&doc_col)?;
4180
4181    // Accumulate small batches until this threshold before dispatching a task.
4182    const ACCUMULATE_BYTES: usize = 256 * 1024;
4183    // Slice oversized batches down to roughly this size.
4184    const SLICE_BYTES: usize = 512 * 1024;
4185
4186    // Phase 1 - rechunk the input stream into appropriately sized chunks.  Tokenization is
4187    // fairly CPU-intensive, and we don't need too much data to justify a new thread task.
4188    let chunked = lance_arrow::stream::rechunk_stream_by_size(
4189        input,
4190        input_schema,
4191        ACCUMULATE_BYTES,
4192        SLICE_BYTES,
4193    );
4194
4195    // Phase 2 - For each row we need to know the total number of tokens and the count of each
4196    // of the query tokens.  For example, if the query is "book" and the row is "the book shop"
4197    // and we are tokenizing with a whitespace tokenizer, we need to know that there are 3 tokens
4198    // and the token book appears once.
4199    let counted_input =
4200        tokenize_and_count(chunked, tokenizer, query_tokens.clone(), doc_col_idx).await?;
4201
4202    // Phase 3 - Calculate final scores (this is fairly cheap, probably don't need to parallelize)
4203    let scorer = initialize_scorer(base_scorer.as_ref(), query_tokens.as_ref(), &counted_input);
4204    let scores = flat_bm25_score(query_tokens.as_ref(), &counted_input, &scorer)?;
4205
4206    // Finally we emit batches according to the target batch size
4207    let num_out_batches = scores.num_rows().div_ceil(target_batch_size);
4208    let mut batches = Vec::with_capacity(num_out_batches);
4209    for i in 0..num_out_batches {
4210        let start = i * target_batch_size;
4211        let len = (scores.num_rows() - start).min(target_batch_size);
4212        batches.push(Ok(scores.slice(start, len)));
4213    }
4214    Ok(Box::pin(RecordBatchStreamAdapter::new(
4215        FTS_SCHEMA.clone(),
4216        stream::iter(batches),
4217    )))
4218}
4219
/// Returns `true` when `query` is enclosed in a pair of double quotes, e.g.
/// `"hello world"`.
///
/// The opening and closing quote must be two distinct characters: a query that
/// consists of a single `"` is not a phrase query.  An empty quoted string
/// (`""`) still counts as one.
pub fn is_phrase_query(query: &str) -> bool {
    // A lone `"` both starts and ends with a quote, hence the length check.
    query.len() >= 2 && query.starts_with('"') && query.ends_with('"')
}
4223
4224#[cfg(test)]
4225mod tests {
4226    use crate::scalar::inverted::document_tokenizer::DocType;
4227    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
4228    use futures::stream;
4229    use lance_core::cache::LanceCache;
4230    use lance_core::utils::tempfile::TempObjDir;
4231    use lance_io::object_store::ObjectStore;
4232
4233    use crate::metrics::NoOpMetricsCollector;
4234    use crate::prefilter::NoFilter;
4235    use crate::scalar::ScalarIndex;
4236    use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder, inverted_list_schema};
4237    use crate::scalar::inverted::encoding::{
4238        compress_positions, compress_posting_list_with_tail_codec,
4239        decompress_posting_list_with_tail_codec, encode_position_stream_block_into,
4240    };
4241    use crate::scalar::inverted::query::{FtsSearchParams, Operator};
4242    use crate::scalar::lance_format::LanceIndexStore;
4243    use arrow::array::{AsArray, LargeBinaryBuilder, ListBuilder, UInt32Builder};
4244    use arrow::datatypes::{Float32Type, UInt32Type};
4245    use arrow_array::{ArrayRef, Float32Array, RecordBatch, StringArray, UInt32Array, UInt64Array};
4246    use arrow_schema::{DataType, Field, Schema};
4247    use std::collections::HashMap;
4248    use std::sync::Arc;
4249
4250    use super::*;
4251
4252    #[tokio::test]
4253    async fn test_posting_builder_remap() {
4254        let posting_tail_codec = PostingTailCodec::Fixed32;
4255        let mut builder =
4256            PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
4257        let n = BLOCK_SIZE + 3;
4258        for i in 0..n {
4259            builder.add(i as u32, PositionRecorder::Count(1));
4260        }
4261        let removed = vec![5, 7];
4262        builder.remap(&removed);
4263
4264        let mut expected =
4265            PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
4266        for i in 0..n - removed.len() {
4267            expected.add(i as u32, PositionRecorder::Count(1));
4268        }
4269        let expected_entries = expected.iter().collect::<Vec<_>>();
4270        let actual_entries = builder.iter().collect::<Vec<_>>();
4271        assert_eq!(actual_entries, expected_entries);
4272
4273        // BLOCK_SIZE + 3 elements should be reduced to BLOCK_SIZE + 1,
4274        // there are still 2 blocks.
4275        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
4276        let (doc_ids, freqs) = decompress_posting_list_with_tail_codec(
4277            (n - removed.len()) as u32,
4278            batch[POSTING_COL]
4279                .as_list::<i32>()
4280                .value(0)
4281                .as_binary::<i64>(),
4282            posting_tail_codec,
4283        )
4284        .unwrap();
4285        assert!(
4286            doc_ids
4287                .iter()
4288                .zip(expected_entries.iter().map(|(doc_id, _, _)| doc_id))
4289                .all(|(a, b)| a == b)
4290        );
4291        assert!(
4292            freqs
4293                .iter()
4294                .zip(expected_entries.iter().map(|(_, freq, _)| freq))
4295                .all(|(a, b)| a == b)
4296        );
4297    }
4298
4299    #[test]
4300    fn test_posting_builder_size_tracking_matches_structure() {
4301        fn tracked_memory_size(builder: &PostingListBuilder) -> u64 {
4302            let encoded_blocks_size = builder
4303                .encoded_blocks
4304                .iter()
4305                .map(|encoded_blocks| std::mem::size_of::<EncodedBlocks>() + encoded_blocks.size())
4306                .sum::<usize>();
4307            let encoded_positions_size = builder
4308                .encoded_position_blocks
4309                .as_ref()
4310                .map(|positions| std::mem::size_of::<EncodedPositionBlocks>() + positions.size())
4311                .unwrap_or(0usize);
4312            (encoded_blocks_size
4313                + builder.tail_entries.capacity() * std::mem::size_of::<RawDocInfo>()
4314                + builder.tail_positions.size()
4315                + encoded_positions_size) as u64
4316        }
4317
4318        let mut builder = PostingListBuilder::new(true);
4319        for doc_id in 0..(BLOCK_SIZE + 5) as u32 {
4320            builder.add(
4321                doc_id,
4322                PositionRecorder::Position(smallvec::smallvec![1, 3, 5]),
4323            );
4324        }
4325
4326        assert_eq!(builder.size(), tracked_memory_size(&builder));
4327    }
4328
4329    #[test]
4330    fn test_posting_builder_flush_releases_tail_position_capacity() {
4331        let mut builder = PostingListBuilder::new(true);
4332        let positions = smallvec::SmallVec::<[u32; 2]>::from_vec((0..1024).collect());
4333        for doc_id in 0..BLOCK_SIZE as u32 {
4334            builder.add(doc_id, PositionRecorder::Position(positions.clone()));
4335        }
4336
4337        assert_eq!(builder.tail_positions.size(), 0);
4338        assert_eq!(builder.size(), {
4339            let encoded_blocks_size = builder
4340                .encoded_blocks
4341                .iter()
4342                .map(|encoded_blocks| std::mem::size_of::<EncodedBlocks>() + encoded_blocks.size())
4343                .sum::<usize>();
4344            let encoded_positions_size = builder
4345                .encoded_position_blocks
4346                .as_ref()
4347                .map(|positions| std::mem::size_of::<EncodedPositionBlocks>() + positions.size())
4348                .unwrap_or(0usize);
4349            (encoded_blocks_size
4350                + builder.tail_entries.capacity() * std::mem::size_of::<RawDocInfo>()
4351                + builder.tail_positions.size()
4352                + encoded_positions_size) as u64
4353        });
4354    }
4355
4356    #[test]
4357    fn test_posting_builder_streamed_positions_roundtrip() {
4358        let mut builder = PostingListBuilder::new(true);
4359        assert!(builder.add_occurrence(0, 1).unwrap());
4360        assert!(!builder.add_occurrence(0, 4).unwrap());
4361        assert!(!builder.add_occurrence(0, 9).unwrap());
4362        builder.finish_open_doc(0).unwrap();
4363
4364        assert!(builder.add_occurrence(2, 3).unwrap());
4365        builder.finish_open_doc(2).unwrap();
4366
4367        let entries = builder.iter().collect::<Vec<_>>();
4368        assert_eq!(
4369            entries,
4370            vec![
4371                (0_u32, 3_u32, Some(vec![1_u32, 4_u32, 9_u32])),
4372                (2_u32, 1_u32, Some(vec![3_u32])),
4373            ]
4374        );
4375    }
4376
4377    #[test]
4378    fn test_posting_builder_roundtrip_shared_positions() {
4379        let entries = vec![
4380            (0_u32, vec![1_u32, 5]),
4381            (2, vec![0, 4, 9]),
4382            (4, vec![7]),
4383            (8, vec![3, 10]),
4384            (13, vec![2, 11, 30]),
4385        ];
4386        let mut builder =
4387            PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::VarintDelta);
4388        for (doc_id, positions) in &entries {
4389            builder.add(
4390                *doc_id,
4391                PositionRecorder::Position(positions.clone().into()),
4392            );
4393        }
4394
4395        let batch = builder.to_batch(vec![1.0]).unwrap();
4396        assert!(batch.column_by_name(COMPRESSED_POSITION_COL).is_some());
4397        assert!(batch.column_by_name(POSITION_COL).is_none());
4398        assert_eq!(
4399            batch.schema_ref().metadata().get(POSTING_TAIL_CODEC_KEY),
4400            Some(&PostingTailCodec::VarintDelta.as_str().to_owned())
4401        );
4402        assert_eq!(
4403            batch.schema_ref().metadata().get(POSITIONS_LAYOUT_KEY),
4404            Some(&POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned())
4405        );
4406        assert_eq!(
4407            batch.schema_ref().metadata().get(POSITIONS_CODEC_KEY),
4408            Some(&PositionStreamCodec::PackedDelta.as_str().to_owned())
4409        );
4410
4411        let posting =
4412            PostingList::from_batch(&batch, Some(1.0), Some(entries.len() as u32)).unwrap();
4413        let actual = posting
4414            .iter()
4415            .map(|(doc_id, freq, positions)| {
4416                (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4417            })
4418            .collect::<Vec<_>>();
4419        let expected = entries
4420            .iter()
4421            .map(|(doc_id, positions)| (*doc_id, positions.len() as u32, positions.clone()))
4422            .collect::<Vec<_>>();
4423        assert_eq!(actual, expected);
4424    }
4425
4426    #[test]
4427    fn test_posting_builder_roundtrip_legacy_positions() {
4428        let entries = vec![(0_u32, vec![1_u32, 5]), (2, vec![0, 4, 9]), (4, vec![7])];
4429        let mut builder =
4430            PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::Fixed32);
4431        for (doc_id, positions) in &entries {
4432            builder.add(
4433                *doc_id,
4434                PositionRecorder::Position(positions.clone().into()),
4435            );
4436        }
4437
4438        let batch = builder.to_batch(vec![1.0]).unwrap();
4439        assert!(batch.column_by_name(POSITION_COL).is_some());
4440        assert!(batch.column_by_name(COMPRESSED_POSITION_COL).is_none());
4441        assert_eq!(
4442            batch.schema_ref().metadata().get(POSTING_TAIL_CODEC_KEY),
4443            None
4444        );
4445        assert_eq!(
4446            batch.schema_ref().metadata().get(POSITIONS_LAYOUT_KEY),
4447            None
4448        );
4449        assert_eq!(batch.schema_ref().metadata().get(POSITIONS_CODEC_KEY), None);
4450
4451        let posting =
4452            PostingList::from_batch(&batch, Some(1.0), Some(entries.len() as u32)).unwrap();
4453        let actual = posting
4454            .iter()
4455            .map(|(doc_id, freq, positions)| {
4456                (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4457            })
4458            .collect::<Vec<_>>();
4459        let expected = entries
4460            .iter()
4461            .map(|(doc_id, positions)| (*doc_id, positions.len() as u32, positions.clone()))
4462            .collect::<Vec<_>>();
4463        assert_eq!(actual, expected);
4464    }
4465
4466    #[test]
4467    fn test_resolve_fts_format_version_defaults_to_v1() {
4468        assert_eq!(
4469            resolve_fts_format_version(None).unwrap(),
4470            InvertedListFormatVersion::V1
4471        );
4472        assert_eq!(
4473            resolve_fts_format_version(Some("2")).unwrap(),
4474            InvertedListFormatVersion::V2
4475        );
4476    }
4477
4478    #[test]
4479    fn test_legacy_compressed_positions_still_readable() {
4480        let doc_ids = [1_u32, 3_u32];
4481        let frequencies = [2_u32, 3_u32];
4482        let posting = compress_posting_list_with_tail_codec(
4483            doc_ids.len(),
4484            doc_ids.iter(),
4485            frequencies.iter(),
4486            std::iter::once(1.0_f32),
4487            PostingTailCodec::Fixed32,
4488        )
4489        .unwrap();
4490
4491        let mut posting_builder = ListBuilder::new(LargeBinaryBuilder::new());
4492        for idx in 0..posting.len() {
4493            posting_builder.values().append_value(posting.value(idx));
4494        }
4495        posting_builder.append(true);
4496
4497        let mut positions_builder = ListBuilder::new(ListBuilder::new(LargeBinaryBuilder::new()));
4498        for positions in [vec![1_u32, 5_u32], vec![0_u32, 4_u32, 9_u32]] {
4499            let compressed = compress_positions(&positions).unwrap();
4500            let doc_builder = positions_builder.values();
4501            for idx in 0..compressed.len() {
4502                doc_builder.values().append_value(compressed.value(idx));
4503            }
4504            doc_builder.append(true);
4505        }
4506        positions_builder.append(true);
4507
4508        let schema = Arc::new(Schema::new(vec![
4509            Field::new(
4510                POSTING_COL,
4511                DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4512                false,
4513            ),
4514            Field::new(MAX_SCORE_COL, DataType::Float32, false),
4515            Field::new(LENGTH_COL, DataType::UInt32, false),
4516            Field::new(
4517                POSITION_COL,
4518                DataType::List(Arc::new(Field::new(
4519                    "item",
4520                    DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4521                    true,
4522                ))),
4523                false,
4524            ),
4525        ]));
4526        let batch = RecordBatch::try_new(
4527            schema,
4528            vec![
4529                Arc::new(posting_builder.finish()) as ArrayRef,
4530                Arc::new(Float32Array::from(vec![1.0])) as ArrayRef,
4531                Arc::new(UInt32Array::from(vec![doc_ids.len() as u32])) as ArrayRef,
4532                Arc::new(positions_builder.finish()) as ArrayRef,
4533            ],
4534        )
4535        .unwrap();
4536
4537        let posting =
4538            PostingList::from_batch(&batch, Some(1.0), Some(doc_ids.len() as u32)).unwrap();
4539        let actual = posting
4540            .iter()
4541            .map(|(doc_id, freq, positions)| {
4542                (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4543            })
4544            .collect::<Vec<_>>();
4545        assert_eq!(actual, vec![(1, 2, vec![1, 5]), (3, 3, vec![0, 4, 9]),]);
4546    }
4547
4548    #[test]
4549    fn test_shared_stream_v2_without_codec_still_readable() {
4550        let doc_ids = [1_u32, 3_u32];
4551        let frequencies = [2_u32, 3_u32];
4552        let posting = compress_posting_list_with_tail_codec(
4553            doc_ids.len(),
4554            doc_ids.iter(),
4555            frequencies.iter(),
4556            std::iter::once(1.0_f32),
4557            PostingTailCodec::Fixed32,
4558        )
4559        .unwrap();
4560
4561        let mut posting_builder = ListBuilder::new(LargeBinaryBuilder::new());
4562        for idx in 0..posting.len() {
4563            posting_builder.values().append_value(posting.value(idx));
4564        }
4565        posting_builder.append(true);
4566
4567        let positions = vec![1_u32, 5_u32, 0_u32, 4_u32, 9_u32];
4568        let mut encoded_positions = Vec::new();
4569        encode_position_stream_block_into(
4570            &positions,
4571            &frequencies,
4572            PositionStreamCodec::VarintDocDelta,
4573            &mut encoded_positions,
4574        )
4575        .unwrap();
4576
4577        let mut position_offsets = ListBuilder::new(UInt32Builder::new());
4578        position_offsets.values().append_value(0);
4579        position_offsets.append(true);
4580
4581        let schema = Arc::new(Schema::new_with_metadata(
4582            vec![
4583                Field::new(
4584                    POSTING_COL,
4585                    DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4586                    false,
4587                ),
4588                Field::new(MAX_SCORE_COL, DataType::Float32, false),
4589                Field::new(LENGTH_COL, DataType::UInt32, false),
4590                Field::new(COMPRESSED_POSITION_COL, DataType::LargeBinary, false),
4591                Field::new(
4592                    POSITION_BLOCK_OFFSET_COL,
4593                    DataType::List(Arc::new(Field::new("item", DataType::UInt32, true))),
4594                    false,
4595                ),
4596            ],
4597            HashMap::from([(
4598                POSITIONS_LAYOUT_KEY.to_owned(),
4599                POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned(),
4600            )]),
4601        ));
4602        let batch = RecordBatch::try_new(
4603            schema,
4604            vec![
4605                Arc::new(posting_builder.finish()) as ArrayRef,
4606                Arc::new(Float32Array::from(vec![1.0])) as ArrayRef,
4607                Arc::new(UInt32Array::from(vec![doc_ids.len() as u32])) as ArrayRef,
4608                Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
4609                    encoded_positions.as_slice(),
4610                )])) as ArrayRef,
4611                Arc::new(position_offsets.finish()) as ArrayRef,
4612            ],
4613        )
4614        .unwrap();
4615
4616        let posting =
4617            PostingList::from_batch(&batch, Some(1.0), Some(doc_ids.len() as u32)).unwrap();
4618        let actual = posting
4619            .iter()
4620            .map(|(doc_id, freq, positions)| {
4621                (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4622            })
4623            .collect::<Vec<_>>();
4624        assert_eq!(actual, vec![(1, 2, vec![1, 5]), (3, 3, vec![0, 4, 9]),]);
4625    }
4626
4627    #[test]
4628    fn test_shared_position_stream_is_smaller_for_sparse_positions() {
4629        let mut builder =
4630            PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::VarintDelta);
4631        let mut legacy_positions = Vec::with_capacity(BLOCK_SIZE * 4);
4632        for doc_id in 0..(BLOCK_SIZE * 4) as u32 {
4633            let mut positions = vec![doc_id * 3 + 1];
4634            if doc_id % 8 == 0 {
4635                positions.push(doc_id * 3 + 2);
4636            }
4637            builder.add(doc_id, PositionRecorder::Position(positions.clone().into()));
4638            legacy_positions.push(positions);
4639        }
4640
4641        let batch = builder.to_batch(vec![1.0; 4]).unwrap();
4642        let shared_positions_size = batch[COMPRESSED_POSITION_COL].get_buffer_memory_size()
4643            + batch[POSITION_BLOCK_OFFSET_COL].get_buffer_memory_size();
4644
4645        let mut positions_builder = ListBuilder::new(ListBuilder::new(LargeBinaryBuilder::new()));
4646        for positions in legacy_positions {
4647            let compressed = compress_positions(&positions).unwrap();
4648            let doc_builder = positions_builder.values();
4649            for idx in 0..compressed.len() {
4650                doc_builder.values().append_value(compressed.value(idx));
4651            }
4652            doc_builder.append(true);
4653        }
4654        positions_builder.append(true);
4655        let legacy_positions_size = positions_builder.finish().get_buffer_memory_size();
4656
4657        assert!(
4658            shared_positions_size < legacy_positions_size,
4659            "expected shared position stream to be smaller than legacy per-doc storage, shared={shared_positions_size}, legacy={legacy_positions_size}",
4660        );
4661    }
4662
4663    #[test]
4664    fn test_posting_list_batch_matches_docset_scoring() {
4665        let mut docs = DocSet::default();
4666        let num_docs = BLOCK_SIZE + 3;
4667        for doc_id in 0..num_docs as u32 {
4668            docs.append(doc_id as u64, doc_id % 7 + 1);
4669        }
4670
4671        let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
4672        let freqs = doc_ids
4673            .iter()
4674            .map(|doc_id| doc_id % 5 + 1)
4675            .collect::<Vec<_>>();
4676
4677        let mut builder_scores = PostingListBuilder::new(false);
4678        let mut builder_docs = PostingListBuilder::new(false);
4679        for (&doc_id, &freq) in doc_ids.iter().zip(freqs.iter()) {
4680            builder_scores.add(doc_id, PositionRecorder::Count(freq));
4681            builder_docs.add(doc_id, PositionRecorder::Count(freq));
4682        }
4683
4684        let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
4685        let batch_scores = builder_scores.to_batch(block_max_scores).unwrap();
4686        let batch_docs = builder_docs
4687            .to_batch_with_docs(&docs, inverted_list_schema(false))
4688            .unwrap();
4689
4690        let scores_posting = batch_scores[POSTING_COL].as_list::<i32>().value(0);
4691        let scores_posting = scores_posting.as_binary::<i64>();
4692        let docs_posting = batch_docs[POSTING_COL].as_list::<i32>().value(0);
4693        let docs_posting = docs_posting.as_binary::<i64>();
4694        assert_eq!(scores_posting, docs_posting);
4695
4696        let score_left = batch_scores[MAX_SCORE_COL]
4697            .as_primitive::<Float32Type>()
4698            .value(0);
4699        let score_right = batch_docs[MAX_SCORE_COL]
4700            .as_primitive::<Float32Type>()
4701            .value(0);
4702        assert!((score_left - score_right).abs() < 1e-6);
4703
4704        let len_left = batch_scores[LENGTH_COL]
4705            .as_primitive::<UInt32Type>()
4706            .value(0);
4707        let len_right = batch_docs[LENGTH_COL].as_primitive::<UInt32Type>().value(0);
4708        assert_eq!(len_left, len_right);
4709    }
4710
4711    #[tokio::test]
4712    async fn test_remap_to_empty_posting_list() {
4713        let tmpdir = TempObjDir::default();
4714        let store = Arc::new(LanceIndexStore::new(
4715            ObjectStore::local().into(),
4716            tmpdir.clone(),
4717            Arc::new(LanceCache::no_cache()),
4718        ));
4719
4720        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
4721
4722        // index of docs:
4723        // 0: lance
4724        // 1: lake lake
4725        // 2: lake lake lake
4726        builder.tokens.add("lance".to_owned());
4727        builder.tokens.add("lake".to_owned());
4728        builder.posting_lists.push(PostingListBuilder::new(false));
4729        builder.posting_lists.push(PostingListBuilder::new(false));
4730        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
4731        builder.posting_lists[1].add(1, PositionRecorder::Count(2));
4732        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
4733        builder.docs.append(0, 1);
4734        builder.docs.append(1, 1);
4735        builder.docs.append(2, 1);
4736        builder.write(store.as_ref()).await.unwrap();
4737
4738        let index = InvertedPartition::load(
4739            store.clone(),
4740            0,
4741            None,
4742            &LanceCache::no_cache(),
4743            TokenSetFormat::default(),
4744        )
4745        .await
4746        .unwrap();
4747        let mut builder = index.into_builder().await.unwrap();
4748
4749        let mapping = HashMap::from([(0, None), (2, Some(3))]);
4750        builder.remap(&mapping).await.unwrap();
4751
4752        // after remap, the doc 0 is removed, and the doc 2 is updated to 3
4753        assert_eq!(builder.tokens.len(), 1);
4754        assert_eq!(builder.tokens.get("lake"), Some(0));
4755        assert_eq!(builder.posting_lists.len(), 1);
4756        assert_eq!(builder.posting_lists[0].len(), 2);
4757        assert_eq!(builder.docs.len(), 2);
4758        assert_eq!(builder.docs.row_id(0), 1);
4759        assert_eq!(builder.docs.row_id(1), 3);
4760
4761        builder.write(store.as_ref()).await.unwrap();
4762
4763        // remap to delete all docs
4764        let mapping = HashMap::from([(1, None), (3, None)]);
4765        builder.remap(&mapping).await.unwrap();
4766
4767        assert_eq!(builder.tokens.len(), 0);
4768        assert_eq!(builder.posting_lists.len(), 0);
4769        assert_eq!(builder.docs.len(), 0);
4770
4771        builder.write(store.as_ref()).await.unwrap();
4772    }
4773
4774    #[tokio::test]
4775    async fn test_posting_cache_conflict_across_partitions() {
4776        let tmpdir = TempObjDir::default();
4777        let store = Arc::new(LanceIndexStore::new(
4778            ObjectStore::local().into(),
4779            tmpdir.clone(),
4780            Arc::new(LanceCache::no_cache()),
4781        ));
4782
4783        // Create first partition with one token and posting list length 1
4784        let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
4785        builder1.tokens.add("test".to_owned());
4786        builder1.posting_lists.push(PostingListBuilder::new(false));
4787        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
4788        builder1.docs.append(100, 1); // row_id=100, num_tokens=1
4789        builder1.write(store.as_ref()).await.unwrap();
4790
4791        // Create second partition with one token and posting list length 4
4792        let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
4793        builder2.tokens.add("test".to_owned()); // Use same token to test cache prefix fix
4794        builder2.posting_lists.push(PostingListBuilder::new(false));
4795        builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
4796        builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
4797        builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
4798        builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
4799        builder2.docs.append(200, 2); // row_id=200, num_tokens=2
4800        builder2.docs.append(201, 1); // row_id=201, num_tokens=1
4801        builder2.docs.append(202, 3); // row_id=202, num_tokens=3
4802        builder2.docs.append(203, 1); // row_id=203, num_tokens=1
4803        builder2.write(store.as_ref()).await.unwrap();
4804
4805        // Create metadata file with both partitions
4806        let metadata = std::collections::HashMap::from_iter(vec![
4807            (
4808                "partitions".to_owned(),
4809                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
4810            ),
4811            (
4812                "params".to_owned(),
4813                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
4814            ),
4815            (
4816                TOKEN_SET_FORMAT_KEY.to_owned(),
4817                TokenSetFormat::default().to_string(),
4818            ),
4819        ]);
4820        let mut writer = store
4821            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
4822            .await
4823            .unwrap();
4824        writer.finish_with_metadata(metadata).await.unwrap();
4825
4826        // Load the inverted index
4827        let cache = Arc::new(LanceCache::with_capacity(4096));
4828        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
4829            .await
4830            .unwrap();
4831
4832        // Verify the index structure
4833        assert_eq!(index.partitions.len(), 2);
4834        assert_eq!(index.partitions[0].tokens.len(), 1);
4835        assert_eq!(index.partitions[1].tokens.len(), 1);
4836
4837        // Verify the partitions were loaded correctly
4838
4839        // Verify posting list lengths (note: partition order may differ from creation order)
4840        // Verify based on actual loading order
4841        if index.partitions[0].id() == 0 {
4842            // If partition[0] is ID=0, then it should have 1 document
4843            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
4844            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
4845            assert_eq!(index.partitions[0].docs.len(), 1);
4846            assert_eq!(index.partitions[1].docs.len(), 4);
4847        } else {
4848            // If partition[0] is ID=1, then it should have 4 documents
4849            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
4850            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
4851            assert_eq!(index.partitions[0].docs.len(), 4);
4852            assert_eq!(index.partitions[1].docs.len(), 1);
4853        }
4854
4855        // Prewarm the inverted index (this loads posting lists into cache)
4856        index.prewarm().await.unwrap();
4857
4858        let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
4859        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
4860        let prefilter = Arc::new(NoFilter);
4861        let metrics = Arc::new(NoOpMetricsCollector);
4862
4863        let (row_ids, scores) = index
4864            .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None)
4865            .await
4866            .unwrap();
4867
4868        // Verify that we got search results
4869        // Expected to find 5 documents: 1 from first partition, 4 from second partition
4870        assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
4871        assert!(!row_ids.is_empty(), "Should find at least some documents");
4872        assert_eq!(row_ids.len(), scores.len());
4873
4874        // All scores should be positive since all documents contain the search token
4875        for &score in &scores {
4876            assert!(score > 0.0, "All scores should be positive");
4877        }
4878
4879        // Check that we got results from both partitions
4880        assert!(
4881            row_ids.contains(&100),
4882            "Should contain row_id from partition 0"
4883        );
4884        assert!(
4885            row_ids.iter().any(|&id| id >= 200),
4886            "Should contain row_id from partition 1"
4887        );
4888    }
4889
4890    #[tokio::test]
4891    async fn test_modern_prewarm_shrinks_cached_posting_buffers() {
4892        let tmpdir = TempObjDir::default();
4893        let store = Arc::new(LanceIndexStore::new(
4894            ObjectStore::local().into(),
4895            tmpdir.clone(),
4896            Arc::new(LanceCache::no_cache()),
4897        ));
4898
4899        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
4900        builder.tokens.add("alpha".to_owned());
4901        builder.tokens.add("beta".to_owned());
4902        builder.posting_lists.push(PostingListBuilder::new(false));
4903        builder.posting_lists.push(PostingListBuilder::new(false));
4904        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
4905        builder.posting_lists[0].add(1, PositionRecorder::Count(2));
4906        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
4907        builder.posting_lists[1].add(3, PositionRecorder::Count(4));
4908        builder.docs.append(100, 1);
4909        builder.docs.append(101, 2);
4910        builder.docs.append(102, 3);
4911        builder.docs.append(103, 4);
4912        builder.write(store.as_ref()).await.unwrap();
4913
4914        let metadata = std::collections::HashMap::from_iter(vec![
4915            (
4916                "partitions".to_owned(),
4917                serde_json::to_string(&vec![0u64]).unwrap(),
4918            ),
4919            (
4920                "params".to_owned(),
4921                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
4922            ),
4923            (
4924                TOKEN_SET_FORMAT_KEY.to_owned(),
4925                TokenSetFormat::default().to_string(),
4926            ),
4927        ]);
4928        let mut writer = store
4929            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
4930            .await
4931            .unwrap();
4932        writer.finish_with_metadata(metadata).await.unwrap();
4933
4934        let cache = Arc::new(LanceCache::with_capacity(4096));
4935        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
4936            .await
4937            .unwrap();
4938        let inverted_list = &index.partitions[0].inverted_list;
4939        assert!(
4940            inverted_list.offsets.is_none(),
4941            "test should use modern posting layout"
4942        );
4943
4944        inverted_list.prewarm_posting_lists(false).await.unwrap();
4945
4946        let alpha = inverted_list
4947            .index_cache
4948            .get_with_key(&PostingListKey { token_id: 0 })
4949            .await
4950            .unwrap();
4951        let beta = inverted_list
4952            .index_cache
4953            .get_with_key(&PostingListKey { token_id: 1 })
4954            .await
4955            .unwrap();
4956
4957        let PostingList::Compressed(alpha) = alpha.as_ref() else {
4958            panic!("expected compressed posting list for token 0");
4959        };
4960        let PostingList::Compressed(beta) = beta.as_ref() else {
4961            panic!("expected compressed posting list for token 1");
4962        };
4963
4964        assert_ne!(
4965            alpha.blocks.values().as_ptr(),
4966            beta.blocks.values().as_ptr(),
4967            "prewarm should not leave cached posting lists sharing the same values buffer"
4968        );
4969    }
4970
4971    #[tokio::test]
4972    async fn test_prewarm_with_positions_populates_separate_position_cache() {
4973        let tmpdir = TempObjDir::default();
4974        let store = Arc::new(LanceIndexStore::new(
4975            ObjectStore::local().into(),
4976            tmpdir.clone(),
4977            Arc::new(LanceCache::no_cache()),
4978        ));
4979
4980        let mut builder = InnerBuilder::new_with_format_version(
4981            0,
4982            true,
4983            TokenSetFormat::default(),
4984            InvertedListFormatVersion::V1,
4985        );
4986        builder.tokens.add("hello".to_owned());
4987        builder.tokens.add("world".to_owned());
4988        builder
4989            .posting_lists
4990            .push(PostingListBuilder::new_with_posting_tail_codec(
4991                true,
4992                PostingTailCodec::Fixed32,
4993            ));
4994        builder
4995            .posting_lists
4996            .push(PostingListBuilder::new_with_posting_tail_codec(
4997                true,
4998                PostingTailCodec::Fixed32,
4999            ));
5000        builder.posting_lists[0].add(0, PositionRecorder::Position(vec![0].into()));
5001        builder.posting_lists[1].add(0, PositionRecorder::Position(vec![1].into()));
5002        builder.posting_lists[0].add(1, PositionRecorder::Position(vec![0].into()));
5003        builder.posting_lists[1].add(1, PositionRecorder::Position(vec![2].into()));
5004        builder.docs.append(100, 2);
5005        builder.docs.append(101, 2);
5006        builder.write(store.as_ref()).await.unwrap();
5007
5008        let metadata = std::collections::HashMap::from_iter(vec![
5009            (
5010                "partitions".to_owned(),
5011                serde_json::to_string(&vec![0_u64]).unwrap(),
5012            ),
5013            (
5014                "params".to_owned(),
5015                serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5016            ),
5017            (
5018                TOKEN_SET_FORMAT_KEY.to_owned(),
5019                TokenSetFormat::default().to_string(),
5020            ),
5021        ]);
5022        let mut writer = store
5023            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5024            .await
5025            .unwrap();
5026        writer.finish_with_metadata(metadata).await.unwrap();
5027
5028        let cache = Arc::new(LanceCache::with_capacity(4096));
5029        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5030            .await
5031            .unwrap();
5032
5033        index
5034            .prewarm_with_options(&FtsPrewarmOptions::new().with_position(true))
5035            .await
5036            .unwrap();
5037
5038        let inverted_list = &index.partitions[0].inverted_list;
5039        let posting = inverted_list
5040            .index_cache
5041            .get_with_key(&PostingListKey { token_id: 0 })
5042            .await
5043            .unwrap();
5044        assert!(
5045            !posting.has_position(),
5046            "posting cache should remain positions-free after prewarm"
5047        );
5048
5049        let positions = inverted_list
5050            .index_cache
5051            .get_with_key(&PositionKey { token_id: 0 })
5052            .await
5053            .unwrap();
5054        assert!(
5055            matches!(
5056                positions.as_ref().0,
5057                CompressedPositionStorage::LegacyPerDoc(_)
5058            ),
5059            "positions should be stored in the dedicated position cache"
5060        );
5061    }
5062
5063    #[tokio::test]
5064    async fn test_prewarm_with_v2_positions_preserves_shared_stream_codec() {
5065        let tmpdir = TempObjDir::default();
5066        let store = Arc::new(LanceIndexStore::new(
5067            ObjectStore::local().into(),
5068            tmpdir.clone(),
5069            Arc::new(LanceCache::no_cache()),
5070        ));
5071
5072        let format_version = InvertedListFormatVersion::V2;
5073        let posting_tail_codec = format_version.posting_tail_codec();
5074        let mut builder = InnerBuilder::new_with_format_version(
5075            0,
5076            true,
5077            TokenSetFormat::default(),
5078            format_version,
5079        );
5080        builder.tokens.add("body".to_owned());
5081
5082        let mut posting_list =
5083            PostingListBuilder::new_with_posting_tail_codec(true, posting_tail_codec);
5084        let expected = (0..(BLOCK_SIZE + 5) as u32)
5085            .map(|doc_id| {
5086                let positions = vec![doc_id % 3, doc_id % 3 + 2, doc_id % 3 + 5];
5087                posting_list.add(doc_id, PositionRecorder::Position(positions.clone().into()));
5088                builder.docs.append(30_000 + doc_id as u64, 20 + doc_id % 7);
5089                (doc_id, positions.len() as u32, positions)
5090            })
5091            .collect::<Vec<_>>();
5092        builder.posting_lists.push(posting_list);
5093        builder.write(store.as_ref()).await.unwrap();
5094
5095        let metadata = HashMap::from([
5096            (
5097                "partitions".to_owned(),
5098                serde_json::to_string(&vec![0_u64]).unwrap(),
5099            ),
5100            (
5101                "params".to_owned(),
5102                serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5103            ),
5104            (
5105                TOKEN_SET_FORMAT_KEY.to_owned(),
5106                TokenSetFormat::default().to_string(),
5107            ),
5108            (
5109                POSTING_TAIL_CODEC_KEY.to_owned(),
5110                posting_tail_codec.as_str().to_owned(),
5111            ),
5112            (
5113                POSITIONS_LAYOUT_KEY.to_owned(),
5114                POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned(),
5115            ),
5116            (
5117                POSITIONS_CODEC_KEY.to_owned(),
5118                PositionStreamCodec::PackedDelta.as_str().to_owned(),
5119            ),
5120        ]);
5121        let mut writer = store
5122            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5123            .await
5124            .unwrap();
5125        writer.finish_with_metadata(metadata).await.unwrap();
5126
5127        let cache = Arc::new(LanceCache::with_capacity(4096));
5128        let index = InvertedIndex::load(store, None, cache.as_ref())
5129            .await
5130            .unwrap();
5131        index
5132            .prewarm_with_options(&FtsPrewarmOptions::new().with_position(true))
5133            .await
5134            .unwrap();
5135
5136        let actual = index.partitions[0]
5137            .inverted_list
5138            .posting_list(0, true, &NoOpMetricsCollector)
5139            .await
5140            .unwrap()
5141            .iter()
5142            .map(|(doc_id, freq, positions)| {
5143                (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
5144            })
5145            .collect::<Vec<_>>();
5146
5147        assert_eq!(actual, expected);
5148    }
5149
5150    #[test]
5151    fn test_block_max_scores_capacity_matches_block_count() {
5152        let mut docs = DocSet::default();
5153        let num_docs = BLOCK_SIZE * 3 + 7;
5154        let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
5155        for doc_id in &doc_ids {
5156            docs.append(*doc_id as u64, 1);
5157        }
5158
5159        let freqs = vec![1_u32; doc_ids.len()];
5160        let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
5161        let expected_blocks = doc_ids.len().div_ceil(BLOCK_SIZE);
5162
5163        assert_eq!(block_max_scores.len(), expected_blocks);
5164        assert_eq!(block_max_scores.capacity(), expected_blocks);
5165    }
5166
5167    #[tokio::test]
5168    async fn test_bm25_search_uses_global_idf() {
5169        let tmpdir = TempObjDir::default();
5170        let store = Arc::new(LanceIndexStore::new(
5171            ObjectStore::local().into(),
5172            tmpdir.clone(),
5173            Arc::new(LanceCache::no_cache()),
5174        ));
5175
5176        // Partition 0: 3 docs, only one contains "alpha".
5177        let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default());
5178        builder0.tokens.add("alpha".to_owned());
5179        builder0.tokens.add("beta".to_owned());
5180        builder0.posting_lists.push(PostingListBuilder::new(false));
5181        builder0.posting_lists.push(PostingListBuilder::new(false));
5182        builder0.posting_lists[0].add(0, PositionRecorder::Count(1));
5183        builder0.posting_lists[1].add(1, PositionRecorder::Count(1));
5184        builder0.posting_lists[1].add(2, PositionRecorder::Count(1));
5185        builder0.docs.append(100, 1);
5186        builder0.docs.append(101, 1);
5187        builder0.docs.append(102, 1);
5188        builder0.write(store.as_ref()).await.unwrap();
5189
5190        // Partition 1: 1 doc, contains "alpha".
5191        let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default());
5192        builder1.tokens.add("alpha".to_owned());
5193        builder1.posting_lists.push(PostingListBuilder::new(false));
5194        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
5195        builder1.docs.append(200, 1);
5196        builder1.write(store.as_ref()).await.unwrap();
5197
5198        let metadata = std::collections::HashMap::from_iter(vec![
5199            (
5200                "partitions".to_owned(),
5201                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
5202            ),
5203            (
5204                "params".to_owned(),
5205                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5206            ),
5207            (
5208                TOKEN_SET_FORMAT_KEY.to_owned(),
5209                TokenSetFormat::default().to_string(),
5210            ),
5211        ]);
5212        let mut writer = store
5213            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5214            .await
5215            .unwrap();
5216        writer.finish_with_metadata(metadata).await.unwrap();
5217
5218        let cache = Arc::new(LanceCache::with_capacity(4096));
5219        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5220            .await
5221            .unwrap();
5222
5223        let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text));
5224        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
5225        let prefilter = Arc::new(NoFilter);
5226        let metrics = Arc::new(NoOpMetricsCollector);
5227
5228        let (row_ids, scores) = index
5229            .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None)
5230            .await
5231            .unwrap();
5232
5233        assert_eq!(row_ids.len(), 2);
5234        assert!(row_ids.contains(&100));
5235        assert!(row_ids.contains(&200));
5236        assert_eq!(row_ids.len(), scores.len());
5237
5238        let expected_idf = idf(2, 4);
5239        for score in scores {
5240            assert!(
5241                (score - expected_idf).abs() < 1e-6,
5242                "score: {}, expected: {}",
5243                score,
5244                expected_idf
5245            );
5246        }
5247    }
5248
5249    #[tokio::test]
5250    async fn test_phrase_query_reads_legacy_per_doc_positions() {
5251        let tmpdir = TempObjDir::default();
5252        let store = Arc::new(LanceIndexStore::new(
5253            ObjectStore::local().into(),
5254            tmpdir.clone(),
5255            Arc::new(LanceCache::no_cache()),
5256        ));
5257
5258        let mut builder = InnerBuilder::new_with_format_version(
5259            0,
5260            true,
5261            TokenSetFormat::default(),
5262            InvertedListFormatVersion::V1,
5263        );
5264        builder.tokens.add("hello".to_owned());
5265        builder.tokens.add("world".to_owned());
5266        builder
5267            .posting_lists
5268            .push(PostingListBuilder::new_with_posting_tail_codec(
5269                true,
5270                PostingTailCodec::Fixed32,
5271            ));
5272        builder
5273            .posting_lists
5274            .push(PostingListBuilder::new_with_posting_tail_codec(
5275                true,
5276                PostingTailCodec::Fixed32,
5277            ));
5278        builder.posting_lists[0].add(0, PositionRecorder::Position(vec![0].into()));
5279        builder.posting_lists[1].add(0, PositionRecorder::Position(vec![1].into()));
5280        builder.posting_lists[0].add(1, PositionRecorder::Position(vec![0].into()));
5281        builder.posting_lists[1].add(1, PositionRecorder::Position(vec![2].into()));
5282        builder.docs.append(100, 2);
5283        builder.docs.append(101, 2);
5284        builder.write(store.as_ref()).await.unwrap();
5285
5286        let metadata = std::collections::HashMap::from_iter(vec![
5287            (
5288                "partitions".to_owned(),
5289                serde_json::to_string(&vec![0_u64]).unwrap(),
5290            ),
5291            (
5292                "params".to_owned(),
5293                serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5294            ),
5295            (
5296                TOKEN_SET_FORMAT_KEY.to_owned(),
5297                TokenSetFormat::default().to_string(),
5298            ),
5299        ]);
5300        let mut writer = store
5301            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5302            .await
5303            .unwrap();
5304        writer.finish_with_metadata(metadata).await.unwrap();
5305
5306        let cache = Arc::new(LanceCache::with_capacity(4096));
5307        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5308            .await
5309            .unwrap();
5310
5311        let tokens = Arc::new(Tokens::new(
5312            vec!["hello".to_owned(), "world".to_owned()],
5313            DocType::Text,
5314        ));
5315        let params = Arc::new(
5316            FtsSearchParams::new()
5317                .with_limit(Some(10))
5318                .with_phrase_slop(Some(0)),
5319        );
5320        let prefilter = Arc::new(NoFilter);
5321        let metrics = Arc::new(NoOpMetricsCollector);
5322
5323        let (row_ids, _scores) = index
5324            .bm25_search(tokens, params, Operator::And, prefilter, metrics, None)
5325            .await
5326            .unwrap();
5327
5328        assert_eq!(row_ids, vec![100]);
5329    }
5330
5331    #[tokio::test]
5332    async fn test_update_preserves_loaded_v2_format_version() -> Result<()> {
5333        let src_dir = TempObjDir::default();
5334        let dest_dir = TempObjDir::default();
5335        let src_store = Arc::new(LanceIndexStore::new(
5336            ObjectStore::local().into(),
5337            src_dir.clone(),
5338            Arc::new(LanceCache::no_cache()),
5339        ));
5340        let dest_store = Arc::new(LanceIndexStore::new(
5341            ObjectStore::local().into(),
5342            dest_dir.clone(),
5343            Arc::new(LanceCache::no_cache()),
5344        ));
5345
5346        let format_version = InvertedListFormatVersion::V2;
5347        let posting_tail_codec = format_version.posting_tail_codec();
5348        let mut partition = InnerBuilder::new_with_format_version(
5349            0,
5350            false,
5351            TokenSetFormat::default(),
5352            format_version,
5353        );
5354        partition.tokens.add("hello".to_owned());
5355        let mut posting_list =
5356            PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
5357        posting_list.add(0, PositionRecorder::Count(1));
5358        partition.posting_lists.push(posting_list);
5359        partition.docs.append(100, 1);
5360        partition.write(src_store.as_ref()).await?;
5361
5362        let metadata = HashMap::from([
5363            (
5364                "partitions".to_owned(),
5365                serde_json::to_string(&vec![0_u64]).unwrap(),
5366            ),
5367            (
5368                "params".to_owned(),
5369                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5370            ),
5371            (
5372                TOKEN_SET_FORMAT_KEY.to_owned(),
5373                TokenSetFormat::default().to_string(),
5374            ),
5375            (
5376                POSTING_TAIL_CODEC_KEY.to_owned(),
5377                posting_tail_codec.as_str().to_owned(),
5378            ),
5379        ]);
5380        let mut writer = src_store
5381            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5382            .await
5383            .unwrap();
5384        writer.finish_with_metadata(metadata).await.unwrap();
5385
5386        let index = InvertedIndex::load(src_store, None, &LanceCache::no_cache()).await?;
5387        assert_eq!(index.index_version(), format_version.index_version());
5388
5389        let schema = Arc::new(Schema::new(vec![
5390            Field::new("doc", DataType::Utf8, true),
5391            Field::new(ROW_ID, DataType::UInt64, false),
5392        ]));
5393        let docs = Arc::new(StringArray::from(vec![Some("hello again")]));
5394        let row_ids = Arc::new(UInt64Array::from(vec![101u64]));
5395        let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?;
5396        let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)]));
5397        let created = index
5398            .update(Box::pin(stream), dest_store.as_ref(), None)
5399            .await?;
5400
5401        assert_eq!(created.index_version, format_version.index_version());
5402
5403        let updated = InvertedIndex::load(dest_store, None, &LanceCache::no_cache()).await?;
5404        assert_eq!(updated.index_version(), format_version.index_version());
5405        assert_eq!(updated.partitions.len(), 2);
5406        for partition in &updated.partitions {
5407            assert_eq!(
5408                partition.inverted_list.posting_tail_codec(),
5409                posting_tail_codec
5410            );
5411        }
5412
5413        Ok(())
5414    }
5415
5416    #[tokio::test]
5417    async fn test_modern_index_without_deleted_col_has_empty_bitmap() {
5418        // An index created before the deleted_fragments feature was added
5419        // will have a metadata file with num_rows=0 (no record batch data).
5420        // The load path should gracefully handle this with an empty bitmap.
5421        let tmpdir = TempObjDir::default();
5422        let store = Arc::new(LanceIndexStore::new(
5423            ObjectStore::local().into(),
5424            tmpdir.clone(),
5425            Arc::new(LanceCache::no_cache()),
5426        ));
5427
5428        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
5429        builder.tokens.add("test".to_owned());
5430        builder.posting_lists.push(PostingListBuilder::new(false));
5431        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
5432        builder.docs.append(100, 1);
5433        builder.write(store.as_ref()).await.unwrap();
5434
5435        // Write a metadata file WITHOUT the deleted_fragments column
5436        // (simulates an older index version)
5437        let metadata = std::collections::HashMap::from_iter(vec![
5438            (
5439                "partitions".to_owned(),
5440                serde_json::to_string(&vec![0u64]).unwrap(),
5441            ),
5442            (
5443                "params".to_owned(),
5444                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5445            ),
5446            (
5447                TOKEN_SET_FORMAT_KEY.to_owned(),
5448                TokenSetFormat::default().to_string(),
5449            ),
5450        ]);
5451        let mut writer = store
5452            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5453            .await
5454            .unwrap();
5455        writer.finish_with_metadata(metadata).await.unwrap();
5456
5457        let index = InvertedIndex::load(store, None, &LanceCache::no_cache())
5458            .await
5459            .unwrap();
5460        assert!(
5461            index.deleted_fragments().is_empty(),
5462            "index without deleted_fragments column should have empty bitmap"
5463        );
5464    }
5465}