lance_index/scalar/inverted/
index.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::fmt::Debug;
5use std::sync::Arc;
6use std::{
7    cmp::{min, Reverse},
8    collections::BinaryHeap,
9    ops::RangeInclusive,
10};
11use std::{
12    collections::{HashMap, HashSet},
13    ops::Range,
14};
15
16use arrow::{
17    array::LargeBinaryBuilder,
18    datatypes::{self, Float32Type, Int32Type, UInt64Type},
19};
20use arrow::{
21    array::{AsArray, ListBuilder, StringBuilder, UInt32Builder},
22    buffer::OffsetBuffer,
23};
24use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
25use arrow_array::{
26    Array, ArrayRef, BooleanArray, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait,
27    RecordBatch, UInt32Array, UInt64Array,
28};
29use arrow_schema::{DataType, Field, Schema, SchemaRef};
30use async_trait::async_trait;
31use datafusion::execution::SendableRecordBatchStream;
32use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
33use datafusion_common::DataFusionError;
34use deepsize::DeepSizeOf;
35use fst::{Automaton, IntoStreamer, Streamer};
36use futures::{stream, StreamExt, TryStreamExt};
37use itertools::Itertools;
38use lance_arrow::{iter_str_array, RecordBatchExt};
39use lance_core::utils::{
40    mask::RowIdMask,
41    tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
42};
43use lance_core::{container::list::ExpLinkedList, utils::tokio::get_num_compute_intensive_cpus};
44use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
45use moka::future::Cache;
46use roaring::RoaringBitmap;
47use snafu::location;
48use std::sync::LazyLock;
49use tracing::{info, instrument};
50
51use super::{
52    builder::{
53        doc_file_path, inverted_list_schema, posting_file_path, token_file_path, ScoredDoc,
54        BLOCK_SIZE,
55    },
56    iter::PlainPostingListIterator,
57    query::*,
58    scorer::{idf, BM25Scorer, Scorer, B, K1},
59};
60use super::{
61    builder::{InnerBuilder, PositionRecorder},
62    encoding::compress_posting_list,
63    iter::CompressedPostingListIterator,
64};
65use super::{
66    encoding::compress_positions,
67    iter::{PostingListIterator, TokenIterator, TokenSource},
68};
69use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
70use crate::frag_reuse::FragReuseIndex;
71use crate::scalar::{
72    AnyQuery, IndexReader, IndexStore, MetricsCollector, SargableQuery, ScalarIndex, SearchResult,
73};
74use crate::Index;
75use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
76
77pub const TOKENS_FILE: &str = "tokens.lance";
78pub const INVERT_LIST_FILE: &str = "invert.lance";
79pub const DOCS_FILE: &str = "docs.lance";
80pub const METADATA_FILE: &str = "metadata.lance";
81
82pub const TOKEN_COL: &str = "_token";
83pub const TOKEN_ID_COL: &str = "_token_id";
84pub const FREQUENCY_COL: &str = "_frequency";
85pub const POSITION_COL: &str = "_position";
86pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
87pub const POSTING_COL: &str = "_posting";
88pub const MAX_SCORE_COL: &str = "_max_score";
89pub const LENGTH_COL: &str = "_length";
90pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
91pub const NUM_TOKEN_COL: &str = "_num_tokens";
92pub const SCORE_COL: &str = "_score";
93pub static SCORE_FIELD: LazyLock<Field> =
94    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
95pub static FTS_SCHEMA: LazyLock<SchemaRef> =
96    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
97
98pub static CACHE_SIZE: LazyLock<usize> = LazyLock::new(|| {
99    std::env::var("LANCE_INVERTED_CACHE_SIZE")
100        .ok()
101        .and_then(|s| s.parse().ok())
102        .unwrap_or(512 * 1024 * 1024)
103});
104
105#[derive(Clone)]
106pub struct InvertedIndex {
107    params: InvertedIndexParams,
108    store: Arc<dyn IndexStore>,
109    tokenizer: tantivy::tokenizer::TextAnalyzer,
110    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
111}
112
113impl Debug for InvertedIndex {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("InvertedIndex")
116            .field("params", &self.params)
117            .field("partitions", &self.partitions)
118            .finish()
119    }
120}
121
122impl DeepSizeOf for InvertedIndex {
123    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
124        self.partitions.deep_size_of_children(context)
125    }
126}
127
128impl InvertedIndex {
129    fn to_builder(&self) -> InvertedIndexBuilder {
130        if self.is_legacy() {
131            // for legacy format, we re-create the index in the new format
132            InvertedIndexBuilder::new(self.params.clone())
133        } else {
134            InvertedIndexBuilder::from_existing_index(
135                self.params.clone(),
136                Some(self.store.clone()),
137                self.partitions.iter().map(|part| part.id).collect(),
138            )
139        }
140    }
141
142    pub fn tokenizer(&self) -> tantivy::tokenizer::TextAnalyzer {
143        self.tokenizer.clone()
144    }
145
146    pub fn params(&self) -> &InvertedIndexParams {
147        &self.params
148    }
149
150    // search the documents that contain the query
151    // return the row ids of the documents sorted by bm25 score
152    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
153    // we first calculate in-partition BM25 scores,
154    // then re-calculate the scores for the top k documents across all partitions
155    #[instrument(level = "debug", skip_all)]
156    pub async fn bm25_search(
157        &self,
158        tokens: Arc<Vec<String>>,
159        params: Arc<FtsSearchParams>,
160        operator: Operator,
161        prefilter: Arc<dyn PreFilter>,
162        metrics: Arc<dyn MetricsCollector>,
163    ) -> Result<(Vec<u64>, Vec<f32>)> {
164        let limit = params.limit.unwrap_or(usize::MAX);
165        if limit == 0 {
166            return Ok((Vec::new(), Vec::new()));
167        }
168        let mask = prefilter.mask();
169        let mut candidates = BinaryHeap::new();
170        let parts = self
171            .partitions
172            .iter()
173            .map(|part| {
174                let part = part.clone();
175                let tokens = tokens.clone();
176                let params = params.clone();
177                let mask = mask.clone();
178                let metrics = metrics.clone();
179                tokio::spawn(async move {
180                    part.bm25_search(
181                        tokens.as_ref(),
182                        params.as_ref(),
183                        operator,
184                        mask,
185                        metrics.as_ref(),
186                    )
187                    .await
188                })
189            })
190            .collect::<Vec<_>>();
191        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
192        let scorer = BM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
193        while let Some(res) = parts.try_next().await? {
194            for (row_id, freq, length) in res? {
195                let mut score = 0.0;
196                for token in tokens.iter() {
197                    score += scorer.score(token, freq, length);
198                }
199                if candidates.len() < limit {
200                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
201                } else if candidates.peek().unwrap().0.score.0 < score {
202                    candidates.pop();
203                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
204                }
205            }
206        }
207
208        Ok(candidates
209            .into_sorted_vec()
210            .into_iter()
211            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
212            .unzip())
213    }
214
215    async fn load_legacy_index(
216        store: Arc<dyn IndexStore>,
217        fri: Option<Arc<FragReuseIndex>>,
218    ) -> Result<Arc<Self>> {
219        log::warn!("loading legacy FTS index");
220        let tokens_fut = tokio::spawn({
221            let store = store.clone();
222            async move {
223                let token_reader = store.open_index_file(TOKENS_FILE).await?;
224                let tokenizer = token_reader
225                    .schema()
226                    .metadata
227                    .get("tokenizer")
228                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
229                    .transpose()?
230                    .unwrap_or_default();
231                let tokens = TokenSet::load(token_reader).await?;
232                Result::Ok((tokenizer, tokens))
233            }
234        });
235        let invert_list_fut = tokio::spawn({
236            let store = store.clone();
237            async move {
238                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
239                let invert_list = PostingListReader::try_new(invert_list_reader).await?;
240                Result::Ok(Arc::new(invert_list))
241            }
242        });
243        let docs_fut = tokio::spawn({
244            let store = store.clone();
245            async move {
246                let docs_reader = store.open_index_file(DOCS_FILE).await?;
247                let docs = DocSet::load(docs_reader, true, fri).await?;
248                Result::Ok(docs)
249            }
250        });
251
252        let (tokenizer_config, tokens) = tokens_fut.await??;
253        let inverted_list = invert_list_fut.await??;
254        let docs = docs_fut.await??;
255
256        let tokenizer = tokenizer_config.build()?;
257
258        Ok(Arc::new(Self {
259            params: tokenizer_config,
260            store: store.clone(),
261            tokenizer,
262            partitions: vec![Arc::new(InvertedPartition {
263                id: 0,
264                store,
265                tokens,
266                inverted_list,
267                docs,
268            })],
269        }))
270    }
271
272    pub fn is_legacy(&self) -> bool {
273        self.partitions.len() == 1 && self.partitions[0].is_legacy()
274    }
275}
276
277#[async_trait]
278impl Index for InvertedIndex {
279    fn as_any(&self) -> &dyn std::any::Any {
280        self
281    }
282
283    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
284        self
285    }
286
287    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
288        Err(Error::invalid_input(
289            "inverted index cannot be cast to vector index",
290            location!(),
291        ))
292    }
293
294    fn statistics(&self) -> Result<serde_json::Value> {
295        let num_tokens = self
296            .partitions
297            .iter()
298            .map(|part| part.tokens.len())
299            .sum::<usize>();
300        let num_docs = self
301            .partitions
302            .iter()
303            .map(|part| part.docs.len())
304            .sum::<usize>();
305        Ok(serde_json::json!({
306            "params": self.params,
307            "num_tokens": num_tokens,
308            "num_docs": num_docs,
309        }))
310    }
311
312    async fn prewarm(&self) -> Result<()> {
313        for part in &self.partitions {
314            part.inverted_list.prewarm().await?;
315        }
316        Ok(())
317    }
318
319    fn index_type(&self) -> crate::IndexType {
320        crate::IndexType::Inverted
321    }
322
323    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
324        unimplemented!()
325    }
326}
327
328#[async_trait]
329impl ScalarIndex for InvertedIndex {
330    // return the row ids of the documents that contain the query
331    #[instrument(level = "debug", skip_all)]
332    async fn search(
333        &self,
334        query: &dyn AnyQuery,
335        _metrics: &dyn MetricsCollector,
336    ) -> Result<SearchResult> {
337        let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
338        return Err(Error::invalid_input(
339            format!("unsupported query {:?} for inverted index", query),
340            location!(),
341        ));
342    }
343
344    fn can_answer_exact(&self, _: &dyn AnyQuery) -> bool {
345        true
346    }
347
348    async fn load(store: Arc<dyn IndexStore>, fri: Option<Arc<FragReuseIndex>>) -> Result<Arc<Self>>
349    where
350        Self: Sized,
351    {
352        // for new index format, there is a metadata file and multiple partitions,
353        // each partition is a separate index containing tokens, inverted list and docs.
354        // for old index format, there is no metadata file, and it's just like a single partition
355
356        match store.open_index_file(METADATA_FILE).await {
357            Ok(reader) => {
358                let params = reader.schema().metadata.get("params").ok_or(Error::Index {
359                    message: "params not found in metadata".to_owned(),
360                    location: location!(),
361                })?;
362                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
363                let partitions =
364                    reader
365                        .schema()
366                        .metadata
367                        .get("partitions")
368                        .ok_or(Error::Index {
369                            message: "partitions not found in metadata".to_owned(),
370                            location: location!(),
371                        })?;
372                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
373
374                let partitions = partitions.into_iter().map(|id| {
375                    let store = store.clone();
376                    let fri_clone = fri.clone();
377                    async move {
378                        Result::Ok(Arc::new(
379                            InvertedPartition::load(store, id, fri_clone).await?,
380                        ))
381                    }
382                });
383                let partitions = stream::iter(partitions)
384                    .buffer_unordered(store.io_parallelism())
385                    .try_collect::<Vec<_>>()
386                    .await?;
387                let tokenizer = params.build()?;
388                Ok(Arc::new(Self {
389                    params,
390                    store,
391                    tokenizer,
392                    partitions,
393                }))
394            }
395            Err(_) => {
396                // old index format
397                Self::load_legacy_index(store, fri).await
398            }
399        }
400    }
401
402    async fn remap(
403        &self,
404        mapping: &HashMap<u64, Option<u64>>,
405        dest_store: &dyn IndexStore,
406    ) -> Result<()> {
407        self.to_builder()
408            .remap(mapping, self.store.clone(), dest_store)
409            .await
410    }
411
412    async fn update(
413        &self,
414        new_data: SendableRecordBatchStream,
415        dest_store: &dyn IndexStore,
416    ) -> Result<()> {
417        self.to_builder().update(new_data, dest_store).await
418    }
419}
420
421#[derive(Debug, Clone, DeepSizeOf)]
422pub struct InvertedPartition {
423    // None for legacy format
424    id: u64,
425    store: Arc<dyn IndexStore>,
426    pub(crate) tokens: TokenSet,
427    pub(crate) inverted_list: Arc<PostingListReader>,
428    pub(crate) docs: DocSet,
429}
430
431impl InvertedPartition {
432    pub fn id(&self) -> u64 {
433        self.id
434    }
435
436    pub fn store(&self) -> &dyn IndexStore {
437        self.store.as_ref()
438    }
439
440    pub fn is_legacy(&self) -> bool {
441        self.inverted_list.lengths.is_none()
442    }
443
444    pub async fn load(
445        store: Arc<dyn IndexStore>,
446        id: u64,
447        fri: Option<Arc<FragReuseIndex>>,
448    ) -> Result<Self> {
449        let token_file = store.open_index_file(&token_file_path(id)).await?;
450        let tokens = TokenSet::load(token_file).await?;
451        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
452        let inverted_list = PostingListReader::try_new(invert_list_file).await?;
453        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
454        let docs = DocSet::load(docs_file, false, fri).await?;
455
456        Ok(Self {
457            id,
458            store,
459            tokens,
460            inverted_list: Arc::new(inverted_list),
461            docs,
462        })
463    }
464
465    fn map(&self, token: &str) -> Option<u32> {
466        self.tokens.get(token)
467    }
468
469    pub fn expand_fuzzy(&self, tokens: &[String], params: &FtsSearchParams) -> Result<Vec<String>> {
470        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
471        for token in tokens {
472            let fuzziness = match params.fuzziness {
473                Some(fuzziness) => fuzziness,
474                None => MatchQuery::auto_fuzziness(token),
475            };
476            let lev =
477                fst::automaton::Levenshtein::new(token, fuzziness).map_err(|e| Error::Index {
478                    message: format!("failed to construct the fuzzy query: {}", e),
479                    location: location!(),
480                })?;
481
482            if let TokenMap::Fst(ref map) = self.tokens.tokens {
483                match params.prefix_length {
484                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
485                    prefix_length => {
486                        let prefix = &token[..min(prefix_length as usize, token.len())];
487                        let prefix = fst::automaton::Str::new(prefix).starts_with();
488                        take_fst_keys(
489                            map.search(lev.intersection(prefix)),
490                            &mut new_tokens,
491                            params.max_expansions,
492                        )
493                    }
494                }
495            } else {
496                return Err(Error::Index {
497                    message: "tokens is not fst, which is not expected".to_owned(),
498                    location: location!(),
499                });
500            }
501        }
502        Ok(new_tokens)
503    }
504
505    // search the documents that contain the query
506    // return the doc info and the doc length
507    // ref: https://en.wikipedia.org/wiki/Okapi_BM25
508    #[instrument(level = "debug", skip_all)]
509    pub async fn bm25_search(
510        &self,
511        tokens: &[String],
512        params: &FtsSearchParams,
513        operator: Operator,
514        mask: Arc<RowIdMask>,
515        metrics: &dyn MetricsCollector,
516    ) -> Result<Vec<(u64, u32, u32)>> {
517        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
518        let is_phrase_query = params.phrase_slop.is_some();
519        let tokens = match is_fuzzy {
520            true => self.expand_fuzzy(tokens, params)?,
521            false => tokens.to_vec(),
522        };
523        let mut token_ids = Vec::with_capacity(tokens.len());
524        for token in tokens {
525            let token_id = self.map(&token);
526            if let Some(token_id) = token_id {
527                token_ids.push((token_id, token));
528            } else if is_phrase_query {
529                // if the token is not found, we can't do phrase query
530                return Ok(Vec::new());
531            }
532        }
533        if token_ids.is_empty() {
534            return Ok(Vec::new());
535        }
536        if !is_phrase_query {
537            // remove duplicates
538            token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
539            token_ids.dedup_by_key(|(token_id, _)| *token_id);
540        }
541
542        let num_docs = self.docs.len();
543        let postings = stream::iter(token_ids)
544            .enumerate()
545            .map(|(position, (token_id, token))| async move {
546                let posting = self
547                    .inverted_list
548                    .posting_list(token_id, is_phrase_query, metrics)
549                    .await?;
550
551                Result::Ok(PostingIterator::new(
552                    token,
553                    token_id,
554                    position as u32,
555                    posting,
556                    num_docs,
557                ))
558            })
559            .buffered(self.store.io_parallelism())
560            .try_collect::<Vec<_>>()
561            .await?;
562        let scorer = BM25Scorer::new(std::iter::once(self));
563        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
564        wand.search(params, mask, metrics)
565    }
566
567    pub async fn into_builder(self) -> Result<InnerBuilder> {
568        let mut builder = InnerBuilder::new(self.id);
569        builder.tokens = self.tokens;
570        builder.docs = self.docs;
571
572        builder
573            .posting_lists
574            .reserve_exact(self.inverted_list.len());
575        for posting_list in self
576            .inverted_list
577            .read_all(self.inverted_list.has_positions())
578            .await?
579        {
580            let posting_list = posting_list?;
581            builder
582                .posting_lists
583                .push(posting_list.into_builder(&builder.docs));
584        }
585        Ok(builder)
586    }
587}
588
589// at indexing, we use HashMap because we need it to be mutable,
590// at searching, we use fst::Map because it's more efficient
591#[derive(Debug, Clone)]
592pub enum TokenMap {
593    HashMap(HashMap<String, u32>),
594    Fst(fst::Map<Vec<u8>>),
595}
596
597impl Default for TokenMap {
598    fn default() -> Self {
599        Self::HashMap(HashMap::new())
600    }
601}
602
603impl DeepSizeOf for TokenMap {
604    fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
605        match self {
606            Self::HashMap(map) => map.deep_size_of_children(ctx),
607            Self::Fst(map) => map.as_fst().size(),
608        }
609    }
610}
611
612impl TokenMap {
613    pub fn len(&self) -> usize {
614        match self {
615            Self::HashMap(map) => map.len(),
616            Self::Fst(map) => map.len(),
617        }
618    }
619
620    pub fn is_empty(&self) -> bool {
621        self.len() == 0
622    }
623}
624
625// TokenSet is a mapping from tokens to token ids
626#[derive(Debug, Clone, Default, DeepSizeOf)]
627pub struct TokenSet {
628    // token -> token_id
629    pub(crate) tokens: TokenMap,
630    pub(crate) next_id: u32,
631    total_length: usize,
632}
633
634impl TokenSet {
635    pub fn into_mut(self) -> Self {
636        let tokens = match self.tokens {
637            TokenMap::HashMap(map) => map,
638            TokenMap::Fst(map) => {
639                let mut new_map = HashMap::with_capacity(map.len());
640                let mut stream = map.into_stream();
641                while let Some((token, token_id)) = stream.next() {
642                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
643                }
644
645                new_map
646            }
647        };
648
649        Self {
650            tokens: TokenMap::HashMap(tokens),
651            next_id: self.next_id,
652            total_length: self.total_length,
653        }
654    }
655
656    pub fn len(&self) -> usize {
657        self.tokens.len()
658    }
659
660    pub fn is_empty(&self) -> bool {
661        self.len() == 0
662    }
663
664    pub(crate) fn iter(&self) -> TokenIterator {
665        TokenIterator::new(match &self.tokens {
666            TokenMap::HashMap(map) => TokenSource::HashMap(map.iter()),
667            TokenMap::Fst(map) => TokenSource::Fst(map.stream()),
668        })
669    }
670
671    pub fn to_batch(self) -> Result<RecordBatch> {
672        let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
673        let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
674
675        match self.tokens {
676            TokenMap::Fst(map) => {
677                let mut stream = map.stream();
678                while let Some((token, token_id)) = stream.next() {
679                    token_builder.append_value(String::from_utf8_lossy(token));
680                    token_id_builder.append_value(token_id as u32);
681                }
682            }
683            TokenMap::HashMap(map) => {
684                for (token, token_id) in map.into_iter().sorted_unstable() {
685                    token_builder.append_value(token);
686                    token_id_builder.append_value(token_id);
687                }
688            }
689        }
690
691        let token_col = token_builder.finish();
692        let token_id_col = token_id_builder.finish();
693
694        let schema = arrow_schema::Schema::new(vec![
695            arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
696            arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
697        ]);
698
699        let batch = RecordBatch::try_new(
700            Arc::new(schema),
701            vec![
702                Arc::new(token_col) as ArrayRef,
703                Arc::new(token_id_col) as ArrayRef,
704            ],
705        )?;
706        Ok(batch)
707    }
708
709    pub async fn load(reader: Arc<dyn IndexReader>) -> Result<Self> {
710        let mut next_id = 0;
711        let mut total_length = 0;
712        let mut tokens = fst::MapBuilder::memory();
713
714        let batch = reader.read_range(0..reader.num_rows(), None).await?;
715        let token_col = batch[TOKEN_COL].as_string::<i32>();
716        let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
717
718        for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
719            let token = token.ok_or(Error::Index {
720                message: "found null token in token set".to_owned(),
721                location: location!(),
722            })?;
723            next_id = next_id.max(token_id + 1);
724            total_length += token.len();
725            tokens
726                .insert(token, token_id as u64)
727                .map_err(|e| Error::Index {
728                    message: format!("failed to insert token {}: {}", token, e),
729                    location: location!(),
730                })?;
731        }
732
733        Ok(Self {
734            tokens: TokenMap::Fst(tokens.into_map()),
735            next_id,
736            total_length,
737        })
738    }
739
740    pub fn add(&mut self, token: String) -> u32 {
741        let next_id = self.next_id();
742        let len = token.len();
743        let token_id = match self.tokens {
744            TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
745            _ => unreachable!("tokens must be HashMap while indexing"),
746        };
747
748        // add token if it doesn't exist
749        if token_id == next_id {
750            self.next_id += 1;
751            self.total_length += len;
752        }
753
754        token_id
755    }
756
757    pub fn get(&self, token: &str) -> Option<u32> {
758        match self.tokens {
759            TokenMap::HashMap(ref map) => map.get(token).copied(),
760            TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
761        }
762    }
763
764    pub fn next_id(&self) -> u32 {
765        self.next_id
766    }
767}
768
769pub struct PostingListReader {
770    reader: Arc<dyn IndexReader>,
771
772    // legacy format only
773    offsets: Option<Vec<usize>>,
774
775    // from metadata for legacy format
776    // from column for new format
777    max_scores: Option<Vec<f32>>,
778
779    // new format only
780    lengths: Option<Vec<u32>>,
781
782    has_position: bool,
783
784    // cache
785    posting_cache: Cache<u32, PostingList>,
786    position_cache: Cache<u32, ListArray>,
787}
788
789impl std::fmt::Debug for PostingListReader {
790    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
791        f.debug_struct("InvertedListReader")
792            .field("offsets", &self.offsets)
793            .field("max_scores", &self.max_scores)
794            .finish()
795    }
796}
797
798impl DeepSizeOf for PostingListReader {
799    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
800        self.offsets.deep_size_of_children(context)
801            + self.max_scores.deep_size_of_children(context)
802            + self.lengths.deep_size_of_children(context)
803            + self.posting_cache.weighted_size() as usize
804            + self.position_cache.weighted_size() as usize
805    }
806}
807
808impl PostingListReader {
809    pub(crate) async fn try_new(reader: Arc<dyn IndexReader>) -> Result<Self> {
810        let has_position = reader.schema().field(POSITION_COL).is_some();
811        let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
812            let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
813            (Some(offsets), max_scores, None)
814        } else {
815            let metadata = reader
816                .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
817                .await?;
818            let max_scores = metadata[MAX_SCORE_COL]
819                .as_primitive::<Float32Type>()
820                .values()
821                .to_vec();
822            let lengths = metadata[LENGTH_COL]
823                .as_primitive::<UInt32Type>()
824                .values()
825                .to_vec();
826            (None, Some(max_scores), Some(lengths))
827        };
828
829        let posting_cache = Cache::builder()
830            .max_capacity(*CACHE_SIZE as u64)
831            .weigher(|_, posting: &PostingList| posting.deep_size_of() as u32)
832            .build();
833        let position_cache = Cache::builder()
834            .max_capacity(*CACHE_SIZE as u64)
835            .weigher(|_, positions: &ListArray| positions.get_array_memory_size() as u32)
836            .build();
837
838        Ok(Self {
839            reader,
840            offsets,
841            max_scores,
842            lengths,
843            has_position,
844            posting_cache,
845            position_cache,
846        })
847    }
848
849    // for legacy format
850    // returns the offsets and max scores
851    fn load_metadata(
852        schema: &lance_core::datatypes::Schema,
853    ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
854        let offsets = schema.metadata.get("offsets").ok_or(Error::Index {
855            message: "offsets not found in metadata".to_owned(),
856            location: location!(),
857        })?;
858        let offsets = serde_json::from_str(offsets)?;
859
860        let max_scores = schema
861            .metadata
862            .get("max_scores")
863            .map(|max_scores| serde_json::from_str(max_scores))
864            .transpose()?;
865        Ok((offsets, max_scores))
866    }
867
868    // the number of posting lists
869    pub fn len(&self) -> usize {
870        match self.offsets {
871            Some(ref offsets) => offsets.len(),
872            None => self.reader.num_rows(),
873        }
874    }
875
876    pub fn is_empty(&self) -> bool {
877        self.len() == 0
878    }
879
880    pub(crate) fn has_positions(&self) -> bool {
881        self.has_position
882    }
883
884    pub(crate) fn posting_len(&self, token_id: u32) -> usize {
885        let token_id = token_id as usize;
886
887        match self.offsets {
888            Some(ref offsets) => {
889                let next_offset = offsets
890                    .get(token_id + 1)
891                    .copied()
892                    .unwrap_or(self.reader.num_rows());
893                next_offset - offsets[token_id]
894            }
895            None => {
896                if let Some(lengths) = &self.lengths {
897                    lengths[token_id] as usize
898                } else {
899                    panic!("posting list reader is not initialized")
900                }
901            }
902        }
903    }
904
905    pub(crate) async fn posting_batch(
906        &self,
907        token_id: u32,
908        with_position: bool,
909    ) -> Result<RecordBatch> {
910        if self.offsets.is_some() {
911            self.posting_batch_legacy(token_id, with_position).await
912        } else {
913            let token_id = token_id as usize;
914            let columns = if with_position {
915                vec![POSTING_COL, POSITION_COL]
916            } else {
917                vec![POSTING_COL]
918            };
919            let batch = self
920                .reader
921                .read_range(token_id..token_id + 1, Some(&columns))
922                .await?;
923            Ok(batch)
924        }
925    }
926
927    async fn posting_batch_legacy(
928        &self,
929        token_id: u32,
930        with_position: bool,
931    ) -> Result<RecordBatch> {
932        let mut columns = vec![ROW_ID, FREQUENCY_COL];
933        if with_position {
934            columns.push(POSITION_COL);
935        }
936
937        let length = self.posting_len(token_id);
938        let token_id = token_id as usize;
939        let offset = self.offsets.as_ref().unwrap()[token_id];
940        let batch = self
941            .reader
942            .read_range(offset..offset + length, Some(&columns))
943            .await?;
944        Ok(batch)
945    }
946
947    #[instrument(level = "debug", skip(self, metrics))]
948    pub(crate) async fn posting_list(
949        &self,
950        token_id: u32,
951        is_phrase_query: bool,
952        metrics: &dyn MetricsCollector,
953    ) -> Result<PostingList> {
954        let mut posting = self
955            .posting_cache
956            .try_get_with(token_id, async move {
957                metrics.record_part_load();
958                info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
959                let batch = self.posting_batch(token_id, false).await?;
960               self.posting_list_from_batch(&batch, token_id)
961            })
962            .await
963            .map_err(|e| Error::io(e.to_string(), location!()))?;
964
965        if is_phrase_query {
966            // hit the cache and when the cache was populated, the positions column was not loaded
967            let positions = self.read_positions(token_id).await?;
968            posting.set_positions(positions);
969        }
970
971        Ok(posting)
972    }
973
974    pub(crate) fn posting_list_from_batch(
975        &self,
976        batch: &RecordBatch,
977        token_id: u32,
978    ) -> Result<PostingList> {
979        let posting_list = PostingList::from_batch(
980            batch,
981            self.max_scores
982                .as_ref()
983                .map(|max_scores| max_scores[token_id as usize]),
984            self.lengths
985                .as_ref()
986                .map(|lengths| lengths[token_id as usize]),
987        )?;
988        Ok(posting_list)
989    }
990
991    async fn prewarm(&self) -> Result<()> {
992        let batch = self.read_batch(false).await?;
993        for token_id in 0..self.len() {
994            let posting_range = self.posting_list_range(token_id as u32);
995            let batch = batch.slice(posting_range.start, posting_range.end - posting_range.start);
996            let posting_list = self.posting_list_from_batch(&batch, token_id as u32)?;
997            self.posting_cache
998                .insert(token_id as u32, posting_list)
999                .await;
1000        }
1001
1002        Ok(())
1003    }
1004
1005    pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1006        let columns = self.posting_columns(with_position);
1007        let batch = self
1008            .reader
1009            .read_range(0..self.reader.num_rows(), Some(&columns))
1010            .await?;
1011        Ok(batch)
1012    }
1013
1014    pub(crate) async fn read_all(
1015        &self,
1016        with_position: bool,
1017    ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1018        let batch = self.read_batch(with_position).await?;
1019        Ok((0..self.len()).map(move |i| {
1020            let token_id = i as u32;
1021            let range = self.posting_list_range(token_id);
1022            let batch = batch.slice(i, range.end - range.start);
1023            self.posting_list_from_batch(&batch, token_id)
1024        }))
1025    }
1026
1027    async fn read_positions(&self, token_id: u32) -> Result<ListArray> {
1028        self.position_cache.try_get_with(token_id, async move {
1029            let batch = self
1030                .reader
1031                .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
1032                .await.map_err(|e| {
1033                    match e {
1034                        Error::Schema { .. } => Error::Index {
1035                            message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
1036                            location: location!(),
1037                        },
1038                        e => e
1039                    }
1040                })?;
1041            Result::Ok(batch[POSITION_COL]
1042                .as_list::<i32>()
1043                .clone())
1044        }).await.map_err(|e| Error::io(e.to_string(), location!()))
1045    }
1046
1047    fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1048        match self.offsets {
1049            Some(ref offsets) => {
1050                let offset = offsets[token_id as usize];
1051                let posting_len = self.posting_len(token_id);
1052                offset..offset + posting_len
1053            }
1054            None => {
1055                let token_id = token_id as usize;
1056                token_id..token_id + 1
1057            }
1058        }
1059    }
1060
1061    fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1062        let mut base_columns = match self.offsets {
1063            Some(_) => vec![ROW_ID, FREQUENCY_COL],
1064            None => vec![POSTING_COL],
1065        };
1066        if with_position {
1067            base_columns.push(POSITION_COL);
1068        }
1069        base_columns
1070    }
1071}
1072
1073#[derive(Debug, Clone, DeepSizeOf)]
1074pub enum PostingList {
1075    Plain(PlainPostingList),
1076    Compressed(CompressedPostingList),
1077}
1078
1079impl PostingList {
1080    pub fn from_batch(
1081        batch: &RecordBatch,
1082        max_score: Option<f32>,
1083        length: Option<u32>,
1084    ) -> Result<Self> {
1085        match batch.column_by_name(POSTING_COL) {
1086            Some(_) => {
1087                debug_assert!(max_score.is_some() && length.is_some());
1088                let posting =
1089                    CompressedPostingList::from_batch(batch, max_score.unwrap(), length.unwrap());
1090                Ok(Self::Compressed(posting))
1091            }
1092            None => {
1093                let posting = PlainPostingList::from_batch(batch, max_score);
1094                Ok(Self::Plain(posting))
1095            }
1096        }
1097    }
1098
1099    pub fn iter(&self) -> PostingListIterator {
1100        PostingListIterator::new(self)
1101    }
1102
1103    pub fn has_position(&self) -> bool {
1104        match self {
1105            Self::Plain(posting) => posting.positions.is_some(),
1106            Self::Compressed(posting) => posting.positions.is_some(),
1107        }
1108    }
1109
1110    pub fn set_positions(&mut self, positions: ListArray) {
1111        match self {
1112            Self::Plain(posting) => posting.positions = Some(positions),
1113            Self::Compressed(posting) => {
1114                posting.positions = Some(positions.value(0).as_list::<i32>().clone());
1115            }
1116        }
1117    }
1118
1119    pub fn max_score(&self) -> Option<f32> {
1120        match self {
1121            Self::Plain(posting) => posting.max_score,
1122            Self::Compressed(posting) => Some(posting.max_score),
1123        }
1124    }
1125
1126    pub fn len(&self) -> usize {
1127        match self {
1128            Self::Plain(posting) => posting.len(),
1129            Self::Compressed(posting) => posting.length as usize,
1130        }
1131    }
1132
1133    pub fn is_empty(&self) -> bool {
1134        self.len() == 0
1135    }
1136
1137    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
1138        let mut builder = PostingListBuilder::new(self.has_position());
1139        match self {
1140            // legacy format
1141            Self::Plain(posting) => {
1142                // convert the posting list to the new format:
1143                // 1. map row ids to doc ids
1144                // 2. sort the posting list by doc ids
1145                struct Item {
1146                    doc_id: u32,
1147                    positions: PositionRecorder,
1148                }
1149                let doc_ids = docs
1150                    .row_ids
1151                    .iter()
1152                    .enumerate()
1153                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
1154                    .collect::<HashMap<_, _>>();
1155                let mut items = Vec::with_capacity(posting.len());
1156                for (row_id, freq, positions) in posting.iter() {
1157                    let freq = freq as u32;
1158                    let positions = match positions {
1159                        Some(positions) => {
1160                            PositionRecorder::Position(positions.collect::<Vec<_>>())
1161                        }
1162                        None => PositionRecorder::Count(freq),
1163                    };
1164                    items.push(Item {
1165                        doc_id: doc_ids[&row_id],
1166                        positions,
1167                    });
1168                }
1169                items.sort_unstable_by_key(|item| item.doc_id);
1170                for item in items {
1171                    builder.add(item.doc_id, item.positions);
1172                }
1173            }
1174            Self::Compressed(posting) => {
1175                posting.iter().for_each(|(doc_id, freq, positions)| {
1176                    let positions = match positions {
1177                        Some(positions) => {
1178                            PositionRecorder::Position(positions.collect::<Vec<_>>())
1179                        }
1180                        None => PositionRecorder::Count(freq),
1181                    };
1182                    builder.add(doc_id, positions);
1183                });
1184            }
1185        }
1186        builder
1187    }
1188}
1189
1190#[derive(Debug, PartialEq, Clone)]
1191pub struct PlainPostingList {
1192    pub row_ids: ScalarBuffer<u64>,
1193    pub frequencies: ScalarBuffer<f32>,
1194    pub max_score: Option<f32>,
1195    pub positions: Option<ListArray>, // List of Int32
1196}
1197
1198impl DeepSizeOf for PlainPostingList {
1199    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1200        self.row_ids.len() * std::mem::size_of::<u64>()
1201            + self.frequencies.len() * std::mem::size_of::<u32>()
1202            + self
1203                .positions
1204                .as_ref()
1205                .map(|positions| positions.get_array_memory_size())
1206                .unwrap_or(0)
1207    }
1208}
1209
1210impl PlainPostingList {
1211    pub fn new(
1212        row_ids: ScalarBuffer<u64>,
1213        frequencies: ScalarBuffer<f32>,
1214        max_score: Option<f32>,
1215        positions: Option<ListArray>,
1216    ) -> Self {
1217        Self {
1218            row_ids,
1219            frequencies,
1220            max_score,
1221            positions,
1222        }
1223    }
1224
1225    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
1226        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
1227        let frequencies = batch[FREQUENCY_COL]
1228            .as_primitive::<Float32Type>()
1229            .values()
1230            .clone();
1231        let positions = batch
1232            .column_by_name(POSITION_COL)
1233            .map(|col| col.as_list::<i32>().clone());
1234
1235        Self::new(row_ids, frequencies, max_score, positions)
1236    }
1237
1238    pub fn len(&self) -> usize {
1239        self.row_ids.len()
1240    }
1241
1242    pub fn is_empty(&self) -> bool {
1243        self.len() == 0
1244    }
1245
1246    pub fn iter(&self) -> PlainPostingListIterator {
1247        Box::new(
1248            self.row_ids
1249                .iter()
1250                .zip(self.frequencies.iter())
1251                .enumerate()
1252                .map(|(idx, (doc_id, freq))| {
1253                    (
1254                        *doc_id,
1255                        *freq,
1256                        self.positions.as_ref().map(|p| {
1257                            let start = p.value_offsets()[idx] as usize;
1258                            let end = p.value_offsets()[idx + 1] as usize;
1259                            Box::new(
1260                                p.values().as_primitive::<Int32Type>().values()[start..end]
1261                                    .iter()
1262                                    .map(|pos| *pos as u32),
1263                            ) as _
1264                        }),
1265                    )
1266                }),
1267        )
1268    }
1269
1270    #[inline]
1271    pub fn doc(&self, i: usize) -> LocatedDocInfo {
1272        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
1273    }
1274
1275    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
1276        self.positions
1277            .as_ref()
1278            .map(|positions| positions.value(index))
1279    }
1280
1281    pub fn max_score(&self) -> Option<f32> {
1282        self.max_score
1283    }
1284
1285    pub fn row_id(&self, i: usize) -> u64 {
1286        self.row_ids[i]
1287    }
1288}
1289
1290#[derive(Debug, PartialEq, Clone)]
1291pub struct CompressedPostingList {
1292    pub max_score: f32,
1293    pub length: u32,
1294    // each binary is a block of compressed data
1295    // that contains `BLOCK_SIZE` doc ids and then `BLOCK_SIZE` frequencies
1296    pub blocks: LargeBinaryArray,
1297    pub positions: Option<ListArray>,
1298}
1299
1300impl DeepSizeOf for CompressedPostingList {
1301    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1302        self.blocks.get_array_memory_size()
1303            + self
1304                .positions
1305                .as_ref()
1306                .map(|positions| positions.get_array_memory_size())
1307                .unwrap_or(0)
1308    }
1309}
1310
1311impl CompressedPostingList {
1312    pub fn new(
1313        blocks: LargeBinaryArray,
1314        max_score: f32,
1315        length: u32,
1316        positions: Option<ListArray>,
1317    ) -> Self {
1318        Self {
1319            max_score,
1320            length,
1321            blocks,
1322            positions,
1323        }
1324    }
1325
1326    pub fn from_batch(batch: &RecordBatch, max_score: f32, length: u32) -> Self {
1327        debug_assert_eq!(batch.num_rows(), 1);
1328        let blocks = batch[POSTING_COL]
1329            .as_list::<i32>()
1330            .value(0)
1331            .as_binary::<i64>()
1332            .clone();
1333        let positions = batch
1334            .column_by_name(POSITION_COL)
1335            .map(|col| col.as_list::<i32>().value(0).as_list::<i32>().clone());
1336
1337        Self {
1338            max_score,
1339            length,
1340            blocks,
1341            positions,
1342        }
1343    }
1344
1345    pub fn iter(&self) -> CompressedPostingListIterator {
1346        CompressedPostingListIterator::new(
1347            self.length as usize,
1348            self.blocks.clone(),
1349            self.positions.clone(),
1350        )
1351    }
1352
1353    pub fn block_max_score(&self, block_idx: usize) -> f32 {
1354        let block = self.blocks.value(block_idx);
1355        block[0..4].try_into().map(f32::from_le_bytes).unwrap()
1356    }
1357
1358    pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
1359        let block = self.blocks.value(block_idx);
1360        block[4..8].try_into().map(u32::from_le_bytes).unwrap()
1361    }
1362}
1363
/// In-memory builder for one token's posting list.
///
/// Doc ids and term frequencies are kept in parallel `ExpLinkedList`s. Per-doc
/// positions are recorded only when the builder was created `with_position`.
#[derive(Debug)]
pub struct PostingListBuilder {
    pub doc_ids: ExpLinkedList<u32>,
    pub frequencies: ExpLinkedList<u32>,
    pub positions: Option<PositionBuilder>,
}

impl PostingListBuilder {
    /// Approximate byte size of the buffered data.
    ///
    /// This is 4 bytes per doc id and per frequency, plus `PositionBuilder::size`
    /// when positions are recorded.
    pub fn size(&self) -> u64 {
        (std::mem::size_of::<u32>() * self.doc_ids.len()
            + std::mem::size_of::<u32>() * self.frequencies.len()
            + self
                .positions
                .as_ref()
                .map(|positions| positions.size())
                .unwrap_or(0)) as u64
    }

    /// Whether this builder records term positions.
    pub fn has_positions(&self) -> bool {
        self.positions.is_some()
    }

    /// Creates an empty builder, recording positions iff `with_position`.
    pub fn new(with_position: bool) -> Self {
        Self {
            // NOTE(review): 128 is passed to `ExpLinkedList::with_capacity_limit`;
            // its exact meaning is defined by `ExpLinkedList`.
            doc_ids: ExpLinkedList::new().with_capacity_limit(128),
            frequencies: ExpLinkedList::new().with_capacity_limit(128),
            positions: with_position.then(PositionBuilder::new),
        }
    }

    /// Number of documents added so far.
    pub fn len(&self) -> usize {
        self.doc_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterates `(doc_id, frequency, positions)` in insertion order.
    ///
    /// `positions` is `None` when positions are not recorded.
    pub fn iter(&self) -> impl Iterator<Item = (&u32, &u32, Option<&[u32]>)> {
        self.doc_ids
            .iter()
            .zip(self.frequencies.iter())
            .enumerate()
            .map(|(idx, (doc_id, freq))| {
                let positions = self.positions.as_ref().map(|positions| positions.get(idx));
                (doc_id, freq, positions)
            })
    }

    /// Appends one document.
    ///
    /// The frequency is `term_positions.len()`. The positions themselves are
    /// kept only when this builder records positions.
    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
        self.doc_ids.push(doc_id);
        self.frequencies.push(term_positions.len());
        if let Some(positions) = self.positions.as_mut() {
            positions.push(term_positions.into_vec());
        }
    }

    /// Serializes the posting list into a single-row batch with the
    /// `inverted_list_schema(self.has_positions())` schema.
    ///
    /// The row holds:
    /// - all compressed posting blocks, as one list of large-binary values;
    /// - the overall max score, the max of `block_max_scores`, or `f32::MIN`
    ///   when that is empty;
    /// - the number of documents;
    /// - when positions are recorded, one compressed entry per document,
    ///   wrapped in a single list.
    ///
    /// Assumes the posting list is sorted by doc id.
    pub fn to_batch(mut self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
        let length = self.len();
        let mut position_builder = self.positions.as_mut().map(|_| {
            ListBuilder::new(ListBuilder::with_capacity(
                LargeBinaryBuilder::new(),
                length,
            ))
        });
        let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);
        // Compress each document's positions into one inner-list entry.
        for index in 0..length {
            if let Some(position_builder) = position_builder.as_mut() {
                let positions = self.positions.as_ref().unwrap().get(index);
                let compressed = compress_positions(positions)?;
                let inner_builder = position_builder.values();
                inner_builder.append_value(compressed.into_iter());
            }
        }
        let compressed = compress_posting_list(
            self.doc_ids.len(),
            self.doc_ids.iter(),
            self.frequencies.iter(),
            block_max_scores.into_iter(),
        )?;
        let schema = inverted_list_schema(self.has_positions());
        // Offsets [0, n] wrap all `n` compressed blocks into a single list row.
        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
        let mut columns = vec![
            Arc::new(ListArray::try_new(
                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
                offsets,
                Arc::new(compressed),
                None,
            )?) as ArrayRef,
            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
            Arc::new(UInt32Array::from_iter_values(std::iter::once(
                self.len() as u32
            ))) as ArrayRef,
        ];

        if let Some(mut position_builder) = position_builder {
            // Close the single outer list row that holds every document's positions.
            position_builder.append(true);
            let position_col = position_builder.finish();
            columns.push(Arc::new(position_col));
        }
        let batch = RecordBatch::try_new(schema, columns)?;
        Ok(batch)
    }

    /// Drops the documents whose ids appear in `removed` and renumbers the rest.
    ///
    /// Each surviving doc id is reduced by the number of removed ids smaller than
    /// it. The cursor walk assumes both `removed` and the posting list are
    /// sorted in ascending doc id order.
    pub fn remap(&mut self, removed: &[u32]) {
        let mut cursor = 0;
        let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
        let mut new_frequencies = ExpLinkedList::with_capacity(self.len());
        let mut new_positions = self.positions.as_mut().map(|_| PositionBuilder::new());
        for (&doc_id, &freq, positions) in self.iter() {
            while cursor < removed.len() && removed[cursor] < doc_id {
                cursor += 1;
            }
            if cursor < removed.len() && removed[cursor] == doc_id {
                // this doc is removed
                continue;
            }
            // there are cursor removed docs before this doc
            // so we need to shift the doc id
            new_doc_ids.push(doc_id - cursor as u32);
            new_frequencies.push(freq);
            // `new_positions` is Some only when positions are recorded, in which
            // case `iter` yields Some positions for every document.
            if let Some(new_positions) = new_positions.as_mut() {
                new_positions.push(positions.unwrap().to_vec());
            }
        }

        self.doc_ids = new_doc_ids;
        self.frequencies = new_frequencies;
        self.positions = new_positions;
    }
}
1496
1497#[derive(Debug, Clone, DeepSizeOf)]
1498pub struct PositionBuilder {
1499    positions: Vec<u32>,
1500    offsets: Vec<i32>,
1501}
1502
1503impl Default for PositionBuilder {
1504    fn default() -> Self {
1505        Self::new()
1506    }
1507}
1508
1509impl PositionBuilder {
1510    pub fn new() -> Self {
1511        Self {
1512            positions: Vec::new(),
1513            offsets: vec![0],
1514        }
1515    }
1516
1517    pub fn size(&self) -> usize {
1518        std::mem::size_of::<u32>() * self.positions.len()
1519            + std::mem::size_of::<i32>() * (self.offsets.len() - 1)
1520    }
1521
1522    pub fn total_len(&self) -> usize {
1523        self.positions.len()
1524    }
1525
1526    pub fn push(&mut self, positions: Vec<u32>) {
1527        self.positions.extend(positions);
1528        self.offsets.push(self.positions.len() as i32);
1529    }
1530
1531    pub fn get(&self, i: usize) -> &[u32] {
1532        let start = self.offsets[i] as usize;
1533        let end = self.offsets[i + 1] as usize;
1534        &self.positions[start..end]
1535    }
1536}
1537
1538impl From<Vec<Vec<u32>>> for PositionBuilder {
1539    fn from(positions: Vec<Vec<u32>>) -> Self {
1540        let mut builder = Self::new();
1541        builder.offsets.reserve(positions.len());
1542        for pos in positions {
1543            builder.push(pos);
1544        }
1545        builder
1546    }
1547}
1548
1549#[derive(Debug, Clone, DeepSizeOf, Copy)]
1550pub enum DocInfo {
1551    Located(LocatedDocInfo),
1552    Raw(RawDocInfo),
1553}
1554
1555impl DocInfo {
1556    pub fn doc_id(&self) -> u64 {
1557        match self {
1558            Self::Raw(info) => info.doc_id as u64,
1559            Self::Located(info) => info.row_id,
1560        }
1561    }
1562
1563    pub fn frequency(&self) -> u32 {
1564        match self {
1565            Self::Raw(info) => info.frequency,
1566            Self::Located(info) => info.frequency as u32,
1567        }
1568    }
1569}
1570
1571impl Eq for DocInfo {}
1572
1573impl PartialEq for DocInfo {
1574    fn eq(&self, other: &Self) -> bool {
1575        self.doc_id() == other.doc_id()
1576    }
1577}
1578
1579impl PartialOrd for DocInfo {
1580    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1581        Some(self.cmp(other))
1582    }
1583}
1584
1585impl Ord for DocInfo {
1586    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1587        self.doc_id().cmp(&other.doc_id())
1588    }
1589}
1590
1591#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
1592pub struct LocatedDocInfo {
1593    pub row_id: u64,
1594    pub frequency: f32,
1595}
1596
1597impl LocatedDocInfo {
1598    pub fn new(row_id: u64, frequency: f32) -> Self {
1599        Self { row_id, frequency }
1600    }
1601}
1602
1603impl Eq for LocatedDocInfo {}
1604
1605impl PartialEq for LocatedDocInfo {
1606    fn eq(&self, other: &Self) -> bool {
1607        self.row_id == other.row_id
1608    }
1609}
1610
1611impl PartialOrd for LocatedDocInfo {
1612    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1613        Some(self.cmp(other))
1614    }
1615}
1616
1617impl Ord for LocatedDocInfo {
1618    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1619        self.row_id.cmp(&other.row_id)
1620    }
1621}
1622
1623#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
1624pub struct RawDocInfo {
1625    pub doc_id: u32,
1626    pub frequency: u32,
1627}
1628
1629impl RawDocInfo {
1630    pub fn new(doc_id: u32, frequency: u32) -> Self {
1631        Self { doc_id, frequency }
1632    }
1633}
1634
1635impl Eq for RawDocInfo {}
1636
1637impl PartialEq for RawDocInfo {
1638    fn eq(&self, other: &Self) -> bool {
1639        self.doc_id == other.doc_id
1640    }
1641}
1642
1643impl PartialOrd for RawDocInfo {
1644    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1645        Some(self.cmp(other))
1646    }
1647}
1648
1649impl Ord for RawDocInfo {
1650    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1651        self.doc_id.cmp(&other.doc_id)
1652    }
1653}
1654
1655// DocSet is a mapping from row ids to the number of tokens in the document
1656// It's used to sort the documents by the bm25 score
1657#[derive(Debug, Clone, Default, DeepSizeOf)]
1658pub struct DocSet {
1659    row_ids: Vec<u64>,
1660    num_tokens: Vec<u32>,
1661    total_tokens: u64,
1662}
1663
1664impl DocSet {
1665    #[inline]
1666    pub fn len(&self) -> usize {
1667        self.row_ids.len()
1668    }
1669
1670    pub fn is_empty(&self) -> bool {
1671        self.len() == 0
1672    }
1673
1674    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
1675        self.row_ids.iter().zip(self.num_tokens.iter())
1676    }
1677
1678    pub fn row_id(&self, doc_id: u32) -> u64 {
1679        self.row_ids[doc_id as usize]
1680    }
1681
1682    pub fn row_range(&self) -> RangeInclusive<u64> {
1683        self.row_ids[0]..=self.row_ids[self.len() - 1]
1684    }
1685
1686    pub fn total_tokens_num(&self) -> u64 {
1687        self.total_tokens
1688    }
1689
1690    #[inline]
1691    pub fn average_length(&self) -> f32 {
1692        self.total_tokens as f32 / self.len() as f32
1693    }
1694
1695    pub fn calculate_block_max_scores<'a>(
1696        &self,
1697        doc_ids: impl Iterator<Item = &'a u32>,
1698        freqs: impl Iterator<Item = &'a u32>,
1699    ) -> Vec<f32> {
1700        let avgdl = self.average_length();
1701        let length = doc_ids.size_hint().0;
1702        let mut block_max_scores = Vec::with_capacity(length);
1703        let mut max_score = f32::MIN;
1704        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
1705            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
1706            let freq = *freq as f32;
1707            let score = freq / (freq + doc_norm);
1708            if score > max_score {
1709                max_score = score;
1710            }
1711            if (i + 1) % BLOCK_SIZE == 0 {
1712                max_score *= idf(length, self.len());
1713                block_max_scores.push(max_score);
1714                max_score = f32::MIN;
1715            }
1716        }
1717        if length % BLOCK_SIZE > 0 {
1718            max_score *= idf(length, self.len());
1719            block_max_scores.push(max_score);
1720        }
1721        block_max_scores
1722    }
1723
1724    pub fn to_batch(&self) -> Result<RecordBatch> {
1725        let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
1726        let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
1727
1728        let schema = arrow_schema::Schema::new(vec![
1729            arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
1730            arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
1731        ]);
1732
1733        let batch = RecordBatch::try_new(
1734            Arc::new(schema),
1735            vec![
1736                Arc::new(row_id_col) as ArrayRef,
1737                Arc::new(num_tokens_col) as ArrayRef,
1738            ],
1739        )?;
1740        Ok(batch)
1741    }
1742
1743    pub async fn load(
1744        reader: Arc<dyn IndexReader>,
1745        is_legacy: bool,
1746        fri: Option<Arc<FragReuseIndex>>,
1747    ) -> Result<Self> {
1748        let batch = reader.read_range(0..reader.num_rows(), None).await?;
1749        let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
1750        let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
1751
1752        let (row_ids, num_tokens) = match is_legacy {
1753            // for legacy format, the row id is doc id,
1754            // in order to support efficient search, we need to sort the row ids,
1755            // so that we can use binary search to get num_tokens
1756            true => row_id_col
1757                .values()
1758                .iter()
1759                .filter_map(|id| {
1760                    if let Some(fri_ref) = fri.as_ref() {
1761                        fri_ref.remap_row_id(*id)
1762                    } else {
1763                        Some(*id)
1764                    }
1765                })
1766                .zip(num_tokens_col.values().iter())
1767                .sorted_unstable_by_key(|x| x.0)
1768                .unzip(),
1769            false => {
1770                let row_ids = row_id_col
1771                    .values()
1772                    .iter()
1773                    .filter_map(|id| {
1774                        if let Some(fri_ref) = fri.as_ref() {
1775                            fri_ref.remap_row_id(*id)
1776                        } else {
1777                            Some(*id)
1778                        }
1779                    })
1780                    .collect();
1781                let num_tokens = num_tokens_col.values().to_vec();
1782                (row_ids, num_tokens)
1783            }
1784        };
1785
1786        let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
1787        Ok(Self {
1788            row_ids,
1789            num_tokens,
1790            total_tokens,
1791        })
1792    }
1793
1794    // remap the row ids to the new row ids
1795    // returns the removed doc ids
1796    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
1797        let mut removed = Vec::new();
1798        let len = self.len();
1799        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
1800        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
1801        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
1802            match mapping.get(&row_id) {
1803                Some(Some(new_row_id)) => {
1804                    self.row_ids.push(*new_row_id);
1805                    self.num_tokens.push(num_token);
1806                }
1807                Some(None) => {
1808                    removed.push(doc_id as u32);
1809                }
1810                None => {
1811                    self.row_ids.push(row_id);
1812                    self.num_tokens.push(num_token);
1813                }
1814            }
1815        }
1816        removed
1817    }
1818
1819    #[inline]
1820    pub fn num_tokens(&self, doc_id: u32) -> u32 {
1821        self.num_tokens[doc_id as usize]
1822    }
1823
1824    #[inline]
1825    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
1826        self.row_ids
1827            .binary_search(&row_id)
1828            .map(|idx| self.num_tokens[idx])
1829            .unwrap_or(0)
1830    }
1831
1832    // append a document to the doc set
1833    // returns the doc_id (the number of documents before appending)
1834    pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
1835        self.row_ids.push(row_id);
1836        self.num_tokens.push(num_tokens);
1837        self.total_tokens += num_tokens as u64;
1838        self.row_ids.len() as u32 - 1
1839    }
1840}
1841
1842pub fn flat_full_text_search(
1843    batches: &[&RecordBatch],
1844    doc_col: &str,
1845    query: &str,
1846    tokenizer: Option<tantivy::tokenizer::TextAnalyzer>,
1847) -> Result<Vec<u64>> {
1848    if batches.is_empty() {
1849        return Ok(vec![]);
1850    }
1851
1852    if is_phrase_query(query) {
1853        return Err(Error::invalid_input(
1854            "phrase query is not supported for flat full text search, try using FTS index",
1855            location!(),
1856        ));
1857    }
1858
1859    match batches[0][doc_col].data_type() {
1860        DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
1861        DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
1862        data_type => Err(Error::invalid_input(
1863            format!("unsupported data type {} for inverted index", data_type),
1864            location!(),
1865        )),
1866    }
1867}
1868
1869fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
1870    batches: &[&RecordBatch],
1871    doc_col: &str,
1872    query: &str,
1873    tokenizer: Option<tantivy::tokenizer::TextAnalyzer>,
1874) -> Result<Vec<u64>> {
1875    let mut results = Vec::new();
1876    let mut tokenizer =
1877        tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
1878    let query_tokens = collect_tokens(query, &mut tokenizer, None)
1879        .into_iter()
1880        .collect::<HashSet<_>>();
1881
1882    for batch in batches {
1883        let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
1884        let doc_array = batch[doc_col].as_string::<Offset>();
1885        for i in 0..row_id_array.len() {
1886            let doc = doc_array.value(i);
1887            let doc_tokens = collect_tokens(doc, &mut tokenizer, Some(&query_tokens));
1888            if !doc_tokens.is_empty() {
1889                results.push(row_id_array.value(i));
1890                assert!(doc.contains(query));
1891            }
1892        }
1893    }
1894
1895    Ok(results)
1896}
1897
1898#[allow(clippy::too_many_arguments)]
1899pub fn flat_bm25_search(
1900    batch: RecordBatch,
1901    doc_col: &str,
1902    query_tokens: &HashSet<String>,
1903    nq: &HashMap<String, usize>,
1904    tokenizer: &mut tantivy::tokenizer::TextAnalyzer,
1905    avgdl: f32,
1906    num_docs: usize,
1907) -> std::result::Result<RecordBatch, DataFusionError> {
1908    let doc_iter = iter_str_array(&batch[doc_col]);
1909    let mut scores = Vec::with_capacity(batch.num_rows());
1910    for doc in doc_iter {
1911        let Some(doc) = doc else {
1912            scores.push(0.0);
1913            continue;
1914        };
1915
1916        let doc_tokens = collect_tokens(doc, tokenizer, Some(query_tokens));
1917        let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / avgdl);
1918        let mut doc_token_count = HashMap::new();
1919        for token in doc_tokens {
1920            doc_token_count
1921                .entry(token)
1922                .and_modify(|count| *count += 1)
1923                .or_insert(1);
1924        }
1925        let mut score = 0.0;
1926        for token in query_tokens.iter() {
1927            let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;
1928
1929            let idf = idf(nq[token], num_docs);
1930            score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
1931        }
1932        scores.push(score);
1933    }
1934
1935    let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
1936    let batch = batch
1937        .try_with_column(SCORE_FIELD.clone(), score_col)?
1938        .project_by_schema(&FTS_SCHEMA)?; // the scan node would probably scan some extra columns for prefilter, drop them here
1939    Ok(batch)
1940}
1941
1942pub fn flat_bm25_search_stream(
1943    input: SendableRecordBatchStream,
1944    doc_col: String,
1945    query: String,
1946    index: &InvertedIndex,
1947) -> SendableRecordBatchStream {
1948    let mut tokenizer = index.tokenizer.clone();
1949    let tokens = collect_tokens(&query, &mut tokenizer, None)
1950        .into_iter()
1951        .sorted_unstable()
1952        .collect::<HashSet<_>>();
1953
1954    let bm25_scorer = BM25Scorer::new(index.partitions.iter().map(|p| p.as_ref()));
1955    let num_docs = bm25_scorer.num_docs();
1956    let avgdl = bm25_scorer.avgdl();
1957    let mut nq = HashMap::with_capacity(tokens.len());
1958    for token in &tokens {
1959        let token_nq = bm25_scorer.nq(token).max(1);
1960        nq.insert(token.clone(), token_nq);
1961    }
1962    let stream = input.map(move |batch| {
1963        let batch = batch?;
1964        let batch = flat_bm25_search(
1965            batch,
1966            &doc_col,
1967            &tokens,
1968            &nq,
1969            &mut tokenizer,
1970            avgdl,
1971            num_docs,
1972        )?;
1973
1974        // filter out rows with score 0
1975        let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
1976        let mask = score_col
1977            .iter()
1978            .map(|score| score.is_some_and(|score| score > 0.0))
1979            .collect::<Vec<_>>();
1980        let mask = BooleanArray::from(mask);
1981        let batch = arrow::compute::filter_record_batch(&batch, &mask)?;
1982        debug_assert!(batch[ROW_ID].null_count() == 0, "flat FTS produces nulls");
1983        Ok(batch)
1984    });
1985
1986    Box::pin(RecordBatchStreamAdapter::new(FTS_SCHEMA.clone(), stream)) as SendableRecordBatchStream
1987}
1988
/// Returns `true` if `query` is a phrase query, i.e. it is wrapped in double quotes.
///
/// The opening and closing quotes must be two distinct characters. A lone `"` is therefore
/// not a phrase query, while `""` (an empty phrase) is.
pub fn is_phrase_query(query: &str) -> bool {
    query
        .strip_prefix('"')
        .is_some_and(|rest| rest.ends_with('"'))
}
1992
#[cfg(test)]
mod tests {
    use crate::scalar::inverted::encoding::decompress_posting_list;

    use super::*;

    /// Removing doc ids from a posting list builder must compact the remaining entries,
    /// and the compressed batch must round-trip to the same doc ids and frequencies.
    #[tokio::test]
    async fn test_posting_builder_remap() {
        let mut builder = PostingListBuilder::new(false);
        let n = BLOCK_SIZE + 3;
        for i in 0..n {
            builder.add(i as u32, PositionRecorder::Count(1));
        }
        let removed = vec![5, 7];
        builder.remap(&removed);

        let mut expected = PostingListBuilder::new(false);
        for i in 0..n - removed.len() {
            expected.add(i as u32, PositionRecorder::Count(1));
        }
        assert_eq!(builder.doc_ids, expected.doc_ids);
        assert_eq!(builder.frequencies, expected.frequencies);

        // BLOCK_SIZE + 3 elements should be reduced to BLOCK_SIZE + 1,
        // there are still 2 blocks.
        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
        let (doc_ids, freqs) = decompress_posting_list(
            (n - removed.len()) as u32,
            batch[POSTING_COL]
                .as_list::<i32>()
                .value(0)
                .as_binary::<i64>(),
        )
        .unwrap();
        // `zip` stops at the shorter side, so check the lengths first. Otherwise a
        // truncated decode would pass the element-wise comparisons vacuously.
        assert_eq!(doc_ids.len(), n - removed.len());
        assert_eq!(freqs.len(), n - removed.len());
        assert!(doc_ids
            .iter()
            .zip(expected.doc_ids.iter())
            .all(|(a, b)| a == b));
        assert!(freqs
            .iter()
            .zip(expected.frequencies.iter())
            .all(|(a, b)| a == b));
    }
}