Skip to main content

lance_index/scalar/
ngram.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::any::Any;
5use std::collections::BTreeMap;
6use std::iter::once;
7use std::time::Instant;
8use std::{collections::HashMap, sync::Arc};
9
10use super::lance_format::LanceIndexStore;
11use super::{
12    AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector,
13    ScalarIndex, ScalarIndexParams, SearchResult, TextQuery,
14};
15use crate::frag_reuse::FragReuseIndex;
16use crate::metrics::NoOpMetricsCollector;
17use crate::pbold;
18use crate::scalar::expression::{ScalarQueryParser, TextQueryParser};
19use crate::scalar::registry::{
20    DefaultTrainingRequest, ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest,
21    VALUE_COLUMN_NAME,
22};
23use crate::scalar::{CreatedIndex, UpdateCriteria};
24use crate::vector::VectorIndex;
25use crate::{Index, IndexType};
26use arrow::array::{AsArray, UInt32Builder};
27use arrow::datatypes::{UInt32Type, UInt64Type};
28use arrow_array::{BinaryArray, RecordBatch, UInt32Array};
29use arrow_schema::{DataType, Field, Schema, SchemaRef};
30use async_trait::async_trait;
31use datafusion::execution::SendableRecordBatchStream;
32use deepsize::DeepSizeOf;
33use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream};
34use lance_arrow::iter_str_array;
35use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
36use lance_core::error::LanceOptionExt;
37use lance_core::utils::address::RowAddress;
38use lance_core::utils::tempfile::TempDir;
39use lance_core::utils::tokio::get_num_compute_intensive_cpus;
40use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
41use lance_core::{Error, utils::mask::RowAddrTreeMap};
42use lance_core::{ROW_ID, Result};
43use lance_io::object_store::ObjectStore;
44use lance_tokenizer::{
45    AlphaNumOnlyFilter, AsciiFoldingFilter, LowerCaser, NgramTokenizer, RawTokenizer, TextAnalyzer,
46};
47use log::info;
48use roaring::{RoaringBitmap, RoaringTreemap};
49use serde::Serialize;
50use tracing::instrument;
51
/// Name of the column that holds each ngram token (as a `u32`) in the postings file.
const TOKENS_COL: &str = "tokens";
/// Name of the column that holds each token's serialized `RoaringTreemap` of row ids.
const POSTING_LIST_COL: &str = "posting_list";
/// File (within the index store) holding the token list and the posting lists.
const POSTINGS_FILENAME: &str = "ngram_postings.lance";
/// Value reported as `CreatedIndex::index_version` for ngram indices written by this module.
const NGRAM_INDEX_VERSION: u32 = 0;
56
57use std::sync::LazyLock;
58
59pub static TOKENS_FIELD: LazyLock<Field> =
60    LazyLock::new(|| Field::new(TOKENS_COL, DataType::UInt32, true));
61pub static POSTINGS_FIELD: LazyLock<Field> =
62    LazyLock::new(|| Field::new(POSTING_LIST_COL, DataType::Binary, false));
63pub static POSTINGS_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
64    Arc::new(Schema::new(vec![
65        TOKENS_FIELD.clone(),
66        POSTINGS_FIELD.clone(),
67    ]))
68});
69pub static TEXT_PREPPER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
70    TextAnalyzer::builder(RawTokenizer::default())
71        .filter(LowerCaser)
72        .filter(AsciiFoldingFilter)
73        .build()
74});
75/// Currently we ALWAYS use trigrams with ascii folding and lower casing.  We may want to make this configurable in the future.
76pub static NGRAM_TOKENIZER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
77    TextAnalyzer::builder(NgramTokenizer::all_ngrams(3, 3).unwrap())
78        .filter(AlphaNumOnlyFilter)
79        .build()
80});
81
82// Helper function to apply a function to each token in a text
83fn tokenize_visitor(tokenizer: &TextAnalyzer, text: &str, mut visitor: impl FnMut(&String)) {
84    // The token_stream method is mutable.  As far as I can tell this is to enforce exclusivity and not
85    // true mutability.  For example, the object returned by `token_stream` has thread-local state but
86    // it is reset each time `token_stream` is called.
87    //
88    // However, I don't see this documented anywhere and I'm not sure about relying on it.  For now, we
89    // make a clone as that seems to be the safer option.  All the tokenizers we use here should be trivially
90    // cloneable (although it requires a heap allocation so may be worth investigating in the future)
91    let mut prepper = TEXT_PREPPER.clone();
92    let mut tokenizer = tokenizer.clone();
93    let mut raw_stream = prepper.token_stream(text);
94    while raw_stream.advance() {
95        let mut token_stream = tokenizer.token_stream(&raw_stream.token().text);
96        while token_stream.advance() {
97            visitor(&token_stream.token().text);
98        }
99    }
100}
101
/// Number of values one ngram position can take: 36 ASCII alphanumerics (encoded 1..=36 by
/// `ngram_to_token`) plus 0, which stands for "no character".
const ALPHA_SPAN: usize = 37;
/// Exclusive upper bound of the tokens `ngram_to_token` produces for digit / lower-case
/// trigrams, i.e. `ALPHA_SPAN^NGRAM_N` = 37^3.
///
/// Used to range-partition tokens across build workers.  (The previous value, `37^2 + 37`, was
/// far below the real token range.  Tokens wrapped around the partitions many times, and any
/// worker count above 1406 produced a zero divisor.)
const MAX_TOKEN: usize = ALPHA_SPAN.pow(NGRAM_N as u32);
/// Smallest token; token 0 is used for NULL / empty values.
const MIN_TOKEN: usize = 0;
/// Length of the indexed ngrams (trigrams).
const NGRAM_N: usize = 3;
106
107// Convert an ngram (string) to a token (u32).  This helps avoid heap allocations
108// and it makes it easier to partition the tokens for shuffling
109//
110// There are 36 alphanumeric values and we add 1 for the NULL token giving us 37^3
111// potential tokens.
112//
113// "" => 0
114// "?" => 37^2 * ?
115// "?$" => 37^2 * ? + 37 * $
116// "?$#" => 37^2 * ? + 37 * $ + #
117// ...
118//
119// The ?,$,# represent the position in the alphabet (+1 to distinguish from NULL)
120//
121// Small strings get the larger multipliers because those ngrams are
122// less likely to be unique and will have larger bitmaps.  We want to
123// spread those out.
124//
125// NOTE: Today we hard-code trigrams and we do not include 1-grams or 2-grams so this
126// function is more general than it needs to be...just in case.
127fn ngram_to_token(ngram: &str, ngram_length: usize) -> u32 {
128    let mut token = 0;
129    // Empty string will get 0
130    for (idx, byte) in ngram.bytes().enumerate() {
131        let pos = if byte <= b'9' {
132            byte - b'0'
133        } else if byte <= b'z' {
134            byte - b'a' + 10
135        } else {
136            unreachable!()
137        } + 1;
138        debug_assert!(pos < ALPHA_SPAN as u8);
139        let mult = ALPHA_SPAN.pow(ngram_length as u32 - idx as u32 - 1) as u32;
140        token += pos as u32 * mult;
141    }
142    token
143}
144
145/// Basic stats about an ngram index
146#[derive(Serialize)]
147struct NGramStatistics {
148    num_ngrams: usize,
149}
150
151/// The row ids that contain a given ngram
152#[derive(Debug)]
153pub struct NGramPostingList {
154    bitmap: RoaringTreemap,
155}
156
157impl DeepSizeOf for NGramPostingList {
158    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
159        self.bitmap.serialized_size()
160    }
161}
162
163// Cache key implementation for type-safe cache access
164#[derive(Debug, Clone)]
165pub struct NGramPostingListKey {
166    pub row_offset: u32,
167}
168
169impl CacheKey for NGramPostingListKey {
170    type ValueType = NGramPostingList;
171
172    fn key(&self) -> std::borrow::Cow<'_, str> {
173        format!("posting-list-{}", self.row_offset).into()
174    }
175
176    fn type_name() -> &'static str {
177        "NGramPostingList"
178    }
179}
180
181impl NGramPostingList {
182    fn try_from_batch(
183        batch: RecordBatch,
184        frag_reuse_index: Option<Arc<FragReuseIndex>>,
185    ) -> Result<Self> {
186        let bitmap_bytes = batch.column(0).as_binary::<i32>().value(0);
187        let mut bitmap = RoaringTreemap::deserialize_from(bitmap_bytes)
188            .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))?;
189        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
190            bitmap = frag_reuse_index_ref.remap_row_ids_roaring_tree_map(&bitmap);
191        }
192        Ok(Self { bitmap })
193    }
194
195    fn intersect<'a>(lists: impl IntoIterator<Item = &'a Self>) -> RoaringTreemap {
196        let mut iter = lists.into_iter();
197        let mut result = iter
198            .next()
199            .map(|list| list.bitmap.clone())
200            .unwrap_or_default();
201        for list in iter {
202            result &= &list.bitmap;
203        }
204        result
205    }
206}
207
208/// Reads on-demand ngram posting lists from storage (and stores them in a cache)
209struct NGramPostingListReader {
210    reader: Arc<dyn IndexReader>,
211    frag_reuse_index: Option<Arc<FragReuseIndex>>,
212    index_cache: WeakLanceCache,
213}
214
215impl DeepSizeOf for NGramPostingListReader {
216    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
217        0
218    }
219}
220
221impl std::fmt::Debug for NGramPostingListReader {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        f.debug_struct("NGramListReader").finish()
224    }
225}
226
227impl NGramPostingListReader {
228    #[instrument(level = "debug", skip(self, metrics))]
229    pub async fn ngram_list(
230        &self,
231        row_offset: u32,
232        metrics: &dyn MetricsCollector,
233    ) -> Result<Arc<NGramPostingList>> {
234        self.index_cache.get_or_insert_with_key(NGramPostingListKey { row_offset }, || async move {
235            metrics.record_part_load();
236                tracing::info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="ngram", part_id=row_offset);
237                let batch = self
238                    .reader
239                    .read_range(
240                        row_offset as usize..row_offset as usize + 1,
241                        Some(&[POSTING_LIST_COL]),
242                    )
243                    .await?;
244                NGramPostingList::try_from_batch(batch, self.frag_reuse_index.clone())
245        }).await
246    }
247}
248
249/// An ngram index
250///
251/// At a high level this is an inverted index that maps ngrams (small fixed size substrings) to the
252/// row ids that contain them.
253///
254/// As a simple example consider a 1-gram index.  It would basically be a mapping from
255/// each letter to the row ids that contain that letter.  Then, if the user searches for
256/// "cat", the index would look up the row ids for "c", "a", and "t", and return the intersection
257/// of those row ids because only rows have at least one c, a, and t could possible contain "cat".
258///
259/// This is an in-exact index, similar to a bloom filter.  It can return false positives and a
260/// recheck step is needed to confirm the results.
261///
262/// Note that it cannot return false negatives.
263pub struct NGramIndex {
264    /// The mapping from tokens to row offsets
265    tokens: HashMap<u32, u32>,
266    /// The reader for the posting lists
267    list_reader: Arc<NGramPostingListReader>,
268    /// The tokenizer used to tokenize text.  Note: not all tokenizers can be used with this index.  For
269    /// example, a stemming tokenizer would not work well because "dozing" would stem to "doze" and if the
270    /// search term is "zing" it would not match.  As a result, this tokenizer is not as configurable as the
271    /// tokenizers used in an inverted index.
272    tokenizer: TextAnalyzer,
273    io_parallelism: usize,
274    /// The store that owns the index
275    store: Arc<dyn IndexStore>,
276}
277
278impl std::fmt::Debug for NGramIndex {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        f.debug_struct("NGramIndex")
281            .field("tokens", &self.tokens)
282            .field("list_reader", &self.list_reader)
283            .finish()
284    }
285}
286
287impl DeepSizeOf for NGramIndex {
288    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
289        self.tokens.deep_size_of_children(context)
290    }
291}
292
293impl NGramIndex {
294    async fn from_store(
295        store: Arc<dyn IndexStore>,
296        frag_reuse_index: Option<Arc<FragReuseIndex>>,
297        index_cache: &LanceCache,
298    ) -> Result<Self> {
299        let tokens = store.open_index_file(POSTINGS_FILENAME).await?;
300        let tokens = tokens
301            .read_range(0..tokens.num_rows(), Some(&[TOKENS_COL]))
302            .await?;
303
304        let tokens_map = HashMap::from_iter(
305            tokens
306                .column(0)
307                .as_primitive::<UInt32Type>()
308                .values()
309                .iter()
310                .copied()
311                .enumerate()
312                .map(|(idx, token)| (token, idx as u32)),
313        );
314
315        let posting_reader = Arc::new(NGramPostingListReader {
316            reader: store.open_index_file(POSTINGS_FILENAME).await?,
317            frag_reuse_index,
318            index_cache: WeakLanceCache::from(index_cache),
319        });
320
321        Ok(Self {
322            io_parallelism: store.io_parallelism(),
323            tokens: tokens_map,
324            list_reader: posting_reader,
325            tokenizer: NGRAM_TOKENIZER.clone(),
326            store,
327        })
328    }
329
330    fn remap_batch(
331        &self,
332        batch: RecordBatch,
333        mapping: &HashMap<u64, Option<u64>>,
334    ) -> Result<RecordBatch> {
335        let posting_lists_array = batch
336            .column_by_name(POSTING_LIST_COL)
337            .expect_ok()?
338            .as_binary::<i32>();
339
340        let new_posting_lists = posting_lists_array
341            .iter()
342            .map(|posting_list| {
343                let posting_list = posting_list.unwrap();
344                let posting_list = RoaringTreemap::deserialize_from(posting_list)?;
345                let new_posting_list =
346                    RoaringTreemap::from_iter(posting_list.into_iter().filter_map(|row_id| {
347                        match mapping.get(&row_id) {
348                            Some(Some(new_row_id)) => Some(*new_row_id),
349                            Some(None) => None,
350                            None => Some(row_id),
351                        }
352                    }));
353                let mut buf = Vec::with_capacity(new_posting_list.serialized_size());
354                new_posting_list.serialize_into(&mut buf)?;
355                Ok(buf)
356            })
357            .collect::<Result<Vec<_>>>()?;
358
359        let new_posting_lists_array = BinaryArray::from_iter_values(new_posting_lists);
360
361        Ok(RecordBatch::try_new(
362            POSTINGS_SCHEMA.clone(),
363            vec![
364                batch.column_by_name(TOKENS_COL).expect_ok()?.clone(),
365                Arc::new(new_posting_lists_array),
366            ],
367        )?)
368    }
369
370    async fn load(
371        store: Arc<dyn IndexStore>,
372        frag_reuse_index: Option<Arc<FragReuseIndex>>,
373        index_cache: &LanceCache,
374    ) -> Result<Arc<Self>>
375    where
376        Self: Sized,
377    {
378        Ok(Arc::new(
379            Self::from_store(store, frag_reuse_index, index_cache).await?,
380        ))
381    }
382}
383
384#[async_trait]
385impl Index for NGramIndex {
386    fn as_any(&self) -> &dyn Any {
387        self
388    }
389
390    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
391        self
392    }
393
394    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
395        Err(Error::invalid_input_source(
396            "NGramIndex is not a vector index".into(),
397        ))
398    }
399
400    fn statistics(&self) -> Result<serde_json::Value> {
401        let ngram_stats = NGramStatistics {
402            num_ngrams: self.tokens.len(),
403        };
404        serde_json::to_value(ngram_stats)
405            .map_err(|e| Error::internal(format!("Error serializing statistics: {}", e)))
406    }
407
408    async fn prewarm(&self) -> Result<()> {
409        // TODO: NGram index can pre-warm by loading all posting lists into memory
410        Ok(())
411    }
412
413    fn index_type(&self) -> IndexType {
414        IndexType::NGram
415    }
416
417    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
418        let mut frag_ids = RoaringBitmap::new();
419        for row_offset in self.tokens.values() {
420            let list = self
421                .list_reader
422                .ngram_list(*row_offset, &NoOpMetricsCollector)
423                .await?;
424            frag_ids.extend(
425                list.bitmap
426                    .iter()
427                    .map(|row_addr| RowAddress::from(row_addr).fragment_id()),
428            );
429        }
430        Ok(frag_ids)
431    }
432}
433
434#[async_trait]
435impl ScalarIndex for NGramIndex {
436    async fn search(
437        &self,
438        query: &dyn AnyQuery,
439        metrics: &dyn MetricsCollector,
440    ) -> Result<SearchResult> {
441        let query = query
442            .as_any()
443            .downcast_ref::<TextQuery>()
444            .ok_or_else(|| Error::invalid_input_source("Query is not a TextQuery".into()))?;
445        match query {
446            TextQuery::StringContains(substr) => {
447                if substr.len() < NGRAM_N {
448                    // We know nothing on short searches, need to recheck all
449                    return Ok(SearchResult::at_least(RowAddrTreeMap::new()));
450                }
451
452                let mut row_offsets = Vec::with_capacity(substr.len() * 3);
453                let mut missing = false;
454                tokenize_visitor(&self.tokenizer, substr, |ngram| {
455                    let token = ngram_to_token(ngram, NGRAM_N);
456                    if let Some(row_offset) = self.tokens.get(&token) {
457                        row_offsets.push(*row_offset);
458                    } else {
459                        missing = true;
460                    }
461                });
462                // At least one token was missing, so we know there are zero results
463                if missing {
464                    return Ok(SearchResult::exact(RowAddrTreeMap::new()));
465                }
466                let posting_lists = futures::stream::iter(
467                    row_offsets
468                        .into_iter()
469                        .map(|row_offset| self.list_reader.ngram_list(row_offset, metrics)),
470                )
471                .buffer_unordered(self.io_parallelism)
472                .try_collect::<Vec<_>>()
473                .await?;
474                metrics.record_comparisons(posting_lists.len());
475                let list_refs = posting_lists.iter().map(|list| list.as_ref());
476                let row_ids = NGramPostingList::intersect(list_refs);
477                Ok(SearchResult::at_most(RowAddrTreeMap::from(row_ids)))
478            }
479        }
480    }
481
482    fn can_remap(&self) -> bool {
483        true
484    }
485
486    async fn remap(
487        &self,
488        mapping: &HashMap<u64, Option<u64>>,
489        dest_store: &dyn IndexStore,
490    ) -> Result<CreatedIndex> {
491        let reader = self.store.open_index_file(POSTINGS_FILENAME).await?;
492        let mut writer = dest_store
493            .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
494            .await?;
495
496        let mut offset = 0;
497        let num_rows = reader.num_rows();
498        const BATCH_SIZE: usize = 64;
499        while offset < num_rows {
500            let batch_size = BATCH_SIZE.min(num_rows - offset);
501            let batch = reader.read_range(offset..offset + batch_size, None).await?;
502            let batch = self.remap_batch(batch, mapping)?;
503            writer.write_record_batch(batch).await?;
504            offset += BATCH_SIZE;
505        }
506
507        writer.finish().await?;
508
509        Ok(CreatedIndex {
510            index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
511                .unwrap(),
512            index_version: NGRAM_INDEX_VERSION,
513            files: Some(dest_store.list_files_with_sizes().await?),
514        })
515    }
516
517    async fn update(
518        &self,
519        new_data: SendableRecordBatchStream,
520        dest_store: &dyn IndexStore,
521        _old_data_filter: Option<super::OldIndexDataFilter>,
522    ) -> Result<CreatedIndex> {
523        let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
524        let spill_files = builder.train(new_data).await?;
525
526        builder
527            .write_index(dest_store, spill_files, Some(self.store.clone()))
528            .await?;
529
530        Ok(CreatedIndex {
531            index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
532                .unwrap(),
533            index_version: NGRAM_INDEX_VERSION,
534            files: Some(dest_store.list_files_with_sizes().await?),
535        })
536    }
537
538    fn update_criteria(&self) -> UpdateCriteria {
539        UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::None).with_row_id())
540    }
541
542    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
543        Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::NGram))
544    }
545}
546
/// Tuning options for building an ngram index.
#[derive(Debug, Clone)]
pub struct NGramIndexBuilderOptions {
    /// Number of (token, row id) pairs a worker accumulates in memory before it spills.
    tokens_per_spill: usize,
}
551
552// A higher value will use more RAM.  A lower value will have to do more spilling
553static DEFAULT_TOKENS_PER_SPILL: LazyLock<usize> = LazyLock::new(|| {
554    std::env::var("LANCE_NGRAM_TOKENS_PER_SPILL")
555        .unwrap_or_else(|_| "1000000000".to_string())
556        .parse()
557        .expect("failed to parse LANCE_NGRAM_TOKENS_PER_SPILL")
558});
559// How many partitions to use for shuffling out the work.  We slightly
560// over-allocate this since the amount of work per-partition is not uniform.
561//
562// Increasing this may increase the performance but it could increase RAM (since we will spill less often)
563// and could hurt performance (since there will be more files at the end for the final spill)
564static DEFAULT_NUM_PARTITIONS: LazyLock<usize> = LazyLock::new(|| {
565    std::env::var("LANCE_NGRAM_NUM_PARTITIONS")
566        .map(|s| s.parse().expect("failed to parse LANCE_NGRAM_PARALLELISM"))
567        .unwrap_or((get_num_compute_intensive_cpus() * 4).max(128))
568});
569// Just enough so that tokenizing is faster than I/O
570static DEFAULT_TOKENIZE_PARALLELISM: LazyLock<usize> = LazyLock::new(|| {
571    std::env::var("LANCE_NGRAM_TOKENIZE_PARALLELISM")
572        .map(|s| {
573            s.parse()
574                .expect("failed to parse LANCE_NGRAM_TOKENIZE_PARALLELISM")
575        })
576        .unwrap_or(8)
577});
578
579impl Default for NGramIndexBuilderOptions {
580    fn default() -> Self {
581        Self {
582            tokens_per_spill: *DEFAULT_TOKENS_PER_SPILL,
583        }
584    }
585}
586
587// An ordered list of tokens and bitmaps
588//
589// The `tokens` list is ordered by token value.  This makes it easier to merge spill files.
590struct NGramIndexSpillState {
591    tokens: UInt32Array,
592    bitmaps: Vec<RoaringTreemap>,
593}
594
595impl NGramIndexSpillState {
596    fn try_from_batch(batch: RecordBatch) -> Result<Self> {
597        let tokens = batch
598            .column_by_name(TOKENS_COL)
599            .expect_ok()?
600            .as_primitive::<UInt32Type>()
601            .clone();
602        let postings = batch
603            .column_by_name(POSTING_LIST_COL)
604            .expect_ok()?
605            .as_binary::<i32>();
606
607        let bitmaps = postings
608            .into_iter()
609            .map(|bytes| {
610                RoaringTreemap::deserialize_from(bytes.expect_ok()?)
611                    .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))
612            })
613            .collect::<Result<Vec<_>>>()?;
614
615        Ok(Self { tokens, bitmaps })
616    }
617
618    fn try_into_batch(self) -> Result<RecordBatch> {
619        let bitmap_array = BinaryArray::from_iter_values(self.bitmaps.into_iter().map(|bitmap| {
620            let mut buf = Vec::with_capacity(bitmap.serialized_size());
621            bitmap.serialize_into(&mut buf).unwrap();
622            buf
623        }));
624        Ok(RecordBatch::try_new(
625            POSTINGS_SCHEMA.clone(),
626            vec![Arc::new(self.tokens), Arc::new(bitmap_array)],
627        )?)
628    }
629}
630
631// As we're building we create a map from ngram to row ids.  When this map gets too large
632// we spill it to disk.
633struct NGramIndexBuildState {
634    tokens_map: BTreeMap<u32, RoaringTreemap>,
635}
636
637impl NGramIndexBuildState {
638    fn starting() -> Self {
639        Self {
640            tokens_map: BTreeMap::new(),
641        }
642    }
643
644    fn take(&mut self) -> Self {
645        let mut taken = Self::starting();
646        std::mem::swap(&mut self.tokens_map, &mut taken.tokens_map);
647        taken
648    }
649
650    fn into_spill(self) -> NGramIndexSpillState {
651        // We can rely on these being in token order because of BTreeMap
652        let tokens = UInt32Array::from_iter_values(self.tokens_map.keys().copied());
653        let bitmaps = Vec::from_iter(self.tokens_map.into_values());
654
655        NGramIndexSpillState { bitmaps, tokens }
656    }
657}
658
659/// A builder for an ngram index
660///
661/// The builder is a small pipeline.  First, we read in the data and tokenize it.  This
662/// stage uses fan-out parallelism to tokenize the data because tokenization may be a little
663/// slower than I/O.
664///
665/// The second stage fans out much wider.  It partitions the tokens into a number of partitions.
666/// Each partition has a BTreemap that maps tokens to row ids.  The partitions then build up
667/// roaring treemaps.  When a partition gets too full it will spill to disk.
668///
669/// Once all the data is processed we spill all the parititons to disk and then we merge the
670/// spill files into a single index file.
671pub struct NGramIndexBuilder {
672    tokenizer: TextAnalyzer,
673    options: NGramIndexBuilderOptions,
674    tmpdir: Arc<TempDir>,
675    spill_store: Arc<dyn IndexStore>,
676
677    tokens_seen: usize,
678    worker_number: usize,
679    has_flushed: bool,
680
681    state: NGramIndexBuildState,
682}
683
684impl NGramIndexBuilder {
685    pub fn try_new(options: NGramIndexBuilderOptions) -> Result<Self> {
686        Self::from_state(NGramIndexBuildState::starting(), options)
687    }
688
689    fn clone_worker(&self, worker_number: usize) -> Self {
690        let mut bitmaps = Vec::with_capacity(36 * 36 * 36 + 1);
691        // Token 0 is always the NULL bitmap
692        bitmaps.push(RoaringTreemap::new());
693        Self {
694            tokenizer: self.tokenizer.clone(),
695            state: NGramIndexBuildState::starting(),
696            tmpdir: self.tmpdir.clone(),
697            spill_store: self.spill_store.clone(),
698            options: self.options.clone(),
699            tokens_seen: 0,
700            worker_number,
701            has_flushed: false,
702        }
703    }
704
705    fn from_state(state: NGramIndexBuildState, options: NGramIndexBuilderOptions) -> Result<Self> {
706        let tokenizer = NGRAM_TOKENIZER.clone();
707
708        let tmpdir = Arc::new(TempDir::default());
709        let spill_store = Arc::new(LanceIndexStore::new(
710            Arc::new(ObjectStore::local()),
711            tmpdir.obj_path(),
712            Arc::new(LanceCache::no_cache()),
713        ));
714
715        Ok(Self {
716            tokenizer,
717            state,
718            tmpdir,
719            spill_store,
720            options,
721            tokens_seen: 0,
722            worker_number: 0,
723            has_flushed: false,
724        })
725    }
726
727    fn validate_schema(schema: &Schema) -> Result<()> {
728        if schema.fields().len() != 2 {
729            return Err(Error::invalid_input_source(
730                "Ngram index schema must have exactly two fields".into(),
731            ));
732        }
733        let values_field = schema.field_with_name(VALUE_COLUMN_NAME)?;
734        if *values_field.data_type() != DataType::Utf8
735            && *values_field.data_type() != DataType::LargeUtf8
736        {
737            return Err(Error::invalid_input_source(
738                "First field in ngram index schema must be of type Utf8/LargeUtf8".into(),
739            ));
740        }
741        let row_id_field = schema.field_with_name(ROW_ID)?;
742        if *row_id_field.data_type() != DataType::UInt64 {
743            return Err(Error::invalid_input_source(
744                "Second field in ngram index schema must be of type UInt64".into(),
745            ));
746        }
747        Ok(())
748    }
749
750    async fn process_batch(&mut self, tokens_and_ids: Vec<(u32, u64)>) -> Result<()> {
751        let mut tokens_seen = 0;
752        for (token, row_id) in tokens_and_ids {
753            tokens_seen += 1;
754            // This would be a bit simpler with entry API but, at scale, the vast majority
755            // of cases will be a hit and we want to avoid cloning the string if we can.  So
756            // for now we do the double-hash.  We can simplify in the future with raw_entry
757            // when it stabilizes.
758            self.state
759                .tokens_map
760                .entry(token)
761                .or_default()
762                .insert(row_id);
763        }
764        self.tokens_seen += tokens_seen;
765        if self.tokens_seen >= self.options.tokens_per_spill {
766            let state = self.state.take();
767            self.flush(state).await?;
768        }
769        Ok(())
770    }
771
772    fn spill_filename(id: usize) -> String {
773        format!("spill-{}.lance", id)
774    }
775
776    fn tmp_spill_filename(id: usize) -> String {
777        format!("spill-{}.lance.tmp", id)
778    }
779
    /// Writes `state` to this worker's spill file.
    ///
    /// If nothing has been seen since the last flush, nothing is written and `has_flushed` is
    /// returned unchanged.  Otherwise the data is persisted and `true` is returned.  In both
    /// cases the return value equals `self.has_flushed` after the call.
    ///
    /// The first flush writes the spill file directly.  Later flushes merge `state` with the
    /// existing spill file into a temporary file, which is then renamed over the spill file.
    /// Only worker builders (non-zero `worker_number`) are expected to flush.
    async fn flush(&mut self, state: NGramIndexBuildState) -> Result<bool> {
        // Nothing accumulated since the last flush, so the state must be empty as well.
        if self.tokens_seen == 0 {
            assert!(state.tokens_map.is_empty());
            return Ok(self.has_flushed);
        }
        self.tokens_seen = 0;
        let spill_state = state.into_spill();
        let flush_start = Instant::now();
        // The primary builder should never flush
        debug_assert_ne!(self.worker_number, 0);
        if self.has_flushed {
            info!("Merging flush for worker {}", self.worker_number);
            // If we have flushed before then we need to merge with the spill file
            let mut writer = self
                .spill_store
                .new_index_file(
                    &Self::tmp_spill_filename(self.worker_number),
                    POSTINGS_SCHEMA.clone(),
                )
                .await?;

            // Left: the new in-memory state.  Right: the data already on disk for this worker.
            let left_stream = stream::once(std::future::ready(Ok(spill_state)));
            let right_stream =
                Self::stream_spill(self.spill_store.clone(), self.worker_number).await?;
            Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
            // NOTE(review): the writer is only dropped here, not finished; this assumes
            // `merge_spill_streams` finishes it — confirm.
            drop(writer);
            // Replace the old spill file with the merged temporary file.
            self.spill_store
                .rename_index_file(
                    &Self::tmp_spill_filename(self.worker_number),
                    &Self::spill_filename(self.worker_number),
                )
                .await?;
        } else {
            // If we haven't flushed before we can just write to the spill file
            info!("Initial flush for worker {}", self.worker_number);
            self.has_flushed = true;
            let writer = self
                .spill_store
                .new_index_file(
                    &Self::spill_filename(self.worker_number),
                    POSTINGS_SCHEMA.clone(),
                )
                .await?;
            self.write(writer, spill_state).await?;
        }
        let flush_time = flush_start.elapsed();
        info!(
            "Flushed worker {} in {}ms",
            self.worker_number,
            flush_time.as_millis()
        );
        Ok(true)
    }
833
834    fn tokenize_and_partition(
835        tokenizer: &TextAnalyzer,
836        batch: RecordBatch,
837        num_workers: usize,
838    ) -> Result<Vec<Vec<(u32, u64)>>> {
839        let text_iter = iter_str_array(batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?);
840        let row_id_col = batch
841            .column_by_name(ROW_ID)
842            .expect_ok()?
843            .as_primitive::<UInt64Type>();
844        // Guessing 1000 tokens per row to at least avoid some of the earlier allocations
845        let mut partitions = vec![Vec::with_capacity(batch.num_rows() * 1000); num_workers];
846        let divisor = (MAX_TOKEN - MIN_TOKEN) / num_workers;
847        for (text, row_id) in text_iter.zip(row_id_col.values()) {
848            if let Some(text) = text {
849                tokenize_visitor(tokenizer, text, |token| {
850                    let token = ngram_to_token(token, NGRAM_N);
851                    let partition_id = (token as usize).saturating_sub(MIN_TOKEN) / divisor;
852                    partitions[partition_id % num_workers].push((token, *row_id));
853                });
854            } else {
855                partitions[0].push((0, *row_id));
856            }
857        }
858        Ok(partitions)
859    }
860
861    pub async fn train(&mut self, data: SendableRecordBatchStream) -> Result<Vec<usize>> {
862        let schema = data.schema();
863        Self::validate_schema(schema.as_ref())?;
864
865        let num_workers = *DEFAULT_NUM_PARTITIONS;
866        let mut senders = Vec::with_capacity(num_workers);
867        let mut builders = Vec::with_capacity(num_workers);
868        for worker_idx in 0..num_workers {
869            let (send, mut recv) = tokio::sync::mpsc::channel(2);
870            senders.push(send);
871
872            let mut builder = self.clone_worker(worker_idx + 1);
873            let future = tokio::spawn(async move {
874                while let Some(partition) = recv.recv().await {
875                    builder.process_batch(partition).await?;
876                }
877                Result::Ok(builder)
878            });
879            builders.push(future);
880        }
881
882        let mut partitions_stream = data
883            .and_then(|batch| {
884                let tokenizer = self.tokenizer.clone();
885                std::future::ready(Ok(tokio::task::spawn(async move {
886                    Ok(Self::tokenize_and_partition(
887                        &tokenizer,
888                        batch,
889                        num_workers,
890                    )?)
891                })
892                .map(|res| res.unwrap())))
893            })
894            .try_buffer_unordered(*DEFAULT_TOKENIZE_PARALLELISM);
895
896        while let Some(partitions) = partitions_stream.try_next().await? {
897            for (part_idx, partition) in partitions.into_iter().enumerate() {
898                senders[part_idx].send(partition).await.unwrap();
899            }
900        }
901
902        std::mem::drop(senders);
903        let builders = futures::future::try_join_all(builders).await?;
904
905        // Final flush is serialized.  If we kick this off in parallel it can
906        // use a lot of memory.
907
908        let mut to_spill = Vec::with_capacity(builders.len());
909
910        for builder in builders {
911            let mut builder = builder?;
912            let state = builder.state.take();
913            if builder.flush(state).await? {
914                to_spill.push(builder.worker_number);
915            }
916        }
917
918        Ok(to_spill)
919    }
920
921    async fn write(
922        &mut self,
923        mut writer: Box<dyn IndexWriter>,
924        state: NGramIndexSpillState,
925    ) -> Result<()> {
926        writer.write_record_batch(state.try_into_batch()?).await?;
927        writer.finish().await?;
928
929        Ok(())
930    }
931
932    async fn stream_spill_reader(
933        reader: Arc<dyn IndexReader>,
934    ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
935        let num_rows = reader.num_rows();
936
937        Ok(stream::try_unfold(0, move |offset| {
938            let reader = reader.clone();
939            async move {
940                // These are small batches but, in the worst case scenario, each row could
941                // be massive (up to 128MB per row at 1B rows) and we end up breaking memory
942                let batch_size = std::cmp::min(num_rows - offset, 64);
943                if batch_size == 0 {
944                    return Ok(None);
945                }
946                let batch = reader.read_range(offset..offset + batch_size, None).await?;
947                let state = NGramIndexSpillState::try_from_batch(batch)?;
948                let new_offset = offset + batch_size;
949                Ok(Some((state, new_offset)))
950            }
951            .boxed()
952        }))
953    }
954
955    async fn stream_spill(
956        spill_store: Arc<dyn IndexStore>,
957        id: usize,
958    ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
959        let reader = spill_store
960            .open_index_file(&Self::spill_filename(id))
961            .await?;
962        Self::stream_spill_reader(reader).await
963    }
964
965    fn merge_spill_states(
966        left_opt: &mut Option<NGramIndexSpillState>,
967        right_opt: &mut Option<NGramIndexSpillState>,
968    ) -> NGramIndexSpillState {
969        let left = left_opt.take().unwrap();
970        let right = right_opt.take().unwrap();
971
972        let item_capacity = left.tokens.len() + right.tokens.len();
973        let mut merged_tokens = UInt32Builder::with_capacity(item_capacity);
974        let mut merged_bitmaps = Vec::with_capacity(left.bitmaps.len() + right.bitmaps.len());
975
976        let mut left_tokens = left.tokens.values().iter().copied();
977        let mut left_bitmaps = left.bitmaps.into_iter();
978        let mut right_tokens = right.tokens.values().iter().copied();
979        let mut right_bitmaps = right.bitmaps.into_iter();
980
981        let mut left_token = left_tokens.next();
982        let mut left_bitmap = left_bitmaps.next();
983        let mut right_token = right_tokens.next();
984        let mut right_bitmap = right_bitmaps.next();
985
986        while left_token.is_some() && right_token.is_some() {
987            let left_token_val = left_token.unwrap();
988            let right_token_val = right_token.unwrap();
989            match left_token_val.cmp(&right_token_val) {
990                std::cmp::Ordering::Less => {
991                    merged_tokens.append_value(left_token_val);
992                    merged_bitmaps.push(left_bitmap.unwrap());
993                    left_token = left_tokens.next();
994                    left_bitmap = left_bitmaps.next();
995                }
996                std::cmp::Ordering::Greater => {
997                    merged_tokens.append_value(right_token_val);
998                    merged_bitmaps.push(right_bitmap.unwrap());
999                    right_token = right_tokens.next();
1000                    right_bitmap = right_bitmaps.next();
1001                }
1002                std::cmp::Ordering::Equal => {
1003                    merged_tokens.append_value(left_token_val);
1004                    merged_bitmaps.push(left_bitmap.unwrap() | &right_bitmap.unwrap());
1005                    left_token = left_tokens.next();
1006                    left_bitmap = left_bitmaps.next();
1007                    right_token = right_tokens.next();
1008                    right_bitmap = right_bitmaps.next();
1009                }
1010            }
1011        }
1012
1013        let collect_remaining = |cur_token, tokens, cur_bitmap, bitmaps| {
1014            let tokens = UInt32Array::from_iter_values(once(cur_token).chain(tokens));
1015            let bitmaps = once(cur_bitmap).chain(bitmaps).collect::<Vec<_>>();
1016            NGramIndexSpillState { tokens, bitmaps }
1017        };
1018
1019        if let Some(left_token) = left_token {
1020            *left_opt = Some(collect_remaining(
1021                left_token,
1022                left_tokens,
1023                left_bitmap.unwrap(),
1024                left_bitmaps,
1025            ));
1026        } else {
1027            *left_opt = None;
1028        }
1029        if let Some(right_token) = right_token {
1030            *right_opt = Some(collect_remaining(
1031                right_token,
1032                right_tokens,
1033                right_bitmap.unwrap(),
1034                right_bitmaps,
1035            ));
1036        } else {
1037            *right_opt = None;
1038        }
1039
1040        NGramIndexSpillState {
1041            tokens: merged_tokens.finish(),
1042            bitmaps: merged_bitmaps,
1043        }
1044    }
1045
1046    async fn merge_spill_streams(
1047        mut left_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1048        mut right_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1049        writer: &mut dyn IndexWriter,
1050    ) -> Result<()> {
1051        let mut left_state = left_stream.try_next().await?;
1052        let mut right_state = right_stream.try_next().await?;
1053
1054        while left_state.is_some() || right_state.is_some() {
1055            if left_state.is_none() {
1056                // Left is done, full drain right
1057                let state = right_state.take().expect_ok()?;
1058                writer.write_record_batch(state.try_into_batch()?).await?;
1059                while let Some(state) = right_stream.try_next().await? {
1060                    writer.write_record_batch(state.try_into_batch()?).await?;
1061                }
1062            } else if right_state.is_none() {
1063                // Right is done, full drain left
1064                let state = left_state.take().expect_ok()?;
1065                writer.write_record_batch(state.try_into_batch()?).await?;
1066                while let Some(state) = left_stream.try_next().await? {
1067                    writer.write_record_batch(state.try_into_batch()?).await?;
1068                }
1069            } else {
1070                // There is a batch from both left and right.  Need to merge them
1071                let merged = Self::merge_spill_states(&mut left_state, &mut right_state);
1072                writer.write_record_batch(merged.try_into_batch()?).await?;
1073                if left_state.is_none() {
1074                    left_state = left_stream.try_next().await?;
1075                }
1076                if right_state.is_none() {
1077                    right_state = right_stream.try_next().await?;
1078                }
1079            }
1080        }
1081
1082        writer.finish().await
1083    }
1084
1085    async fn merge_spill_files(
1086        spill_store: Arc<dyn IndexStore>,
1087        index_of_left: usize,
1088        index_of_right: usize,
1089        output_index: usize,
1090    ) -> Result<()> {
1091        // We fully load the small file into memory and then stream the large file
1092        info!(
1093            "Merge spill files {} and {} into {}",
1094            index_of_left, index_of_right, output_index
1095        );
1096
1097        let mut writer = spill_store
1098            .new_index_file(&Self::spill_filename(output_index), POSTINGS_SCHEMA.clone())
1099            .await?;
1100
1101        let (left_stream, right_stream) = futures::try_join!(
1102            Self::stream_spill(spill_store.clone(), index_of_left),
1103            Self::stream_spill(spill_store.clone(), index_of_right)
1104        )?;
1105
1106        Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1107
1108        spill_store
1109            .delete_index_file(&Self::spill_filename(index_of_left))
1110            .await?;
1111        spill_store
1112            .delete_index_file(&Self::spill_filename(index_of_right))
1113            .await?;
1114
1115        Ok(())
1116    }
1117
1118    // Can potentially parallelize in the future if this step becomes a bottleneck
1119    //
1120    // We can also merge in a more balanced fashion (e.g. binary tree) to reduce the size of
1121    // intermediate files
1122    //
1123    // Note: worker indices start at 1 and not 0 (hence all the +1's)
1124    async fn merge_spills(&mut self, mut spill_files: Vec<usize>) -> Result<usize> {
1125        info!(
1126            "Merging {} index files into one combined index",
1127            spill_files.len()
1128        );
1129
1130        let mut spill_counter = spill_files.iter().max().expect_ok()? + 1;
1131        while spill_files.len() > 1 {
1132            let mut new_spills = Vec::with_capacity(spill_files.len() / 2);
1133            while spill_files.len() >= 2 {
1134                let left = spill_files.pop().expect_ok()?;
1135                let right = spill_files.pop().expect_ok()?;
1136                new_spills.push(tokio::spawn(Self::merge_spill_files(
1137                    self.spill_store.clone(),
1138                    left,
1139                    right,
1140                    spill_counter + new_spills.len(),
1141                )));
1142            }
1143            for i in 0..new_spills.len() {
1144                spill_files.push(spill_counter + i);
1145            }
1146            spill_counter += new_spills.len();
1147            futures::future::try_join_all(new_spills).await?;
1148        }
1149
1150        spill_files.pop().expect_ok()
1151    }
1152
1153    async fn merge_old_index(
1154        &mut self,
1155        new_data_num: usize,
1156        old_index: Arc<dyn IndexStore>,
1157    ) -> Result<usize> {
1158        info!("Merging old index into new index");
1159        let final_num = new_data_num + 1;
1160
1161        let mut writer = self
1162            .spill_store
1163            .new_index_file(&Self::spill_filename(final_num), POSTINGS_SCHEMA.clone())
1164            .await?;
1165
1166        let left_stream = Self::stream_spill(self.spill_store.clone(), new_data_num).await?;
1167        let old_reader = old_index.open_index_file(POSTINGS_FILENAME).await?;
1168        let right_stream = Self::stream_spill_reader(old_reader).await?;
1169
1170        Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1171
1172        self.spill_store
1173            .delete_index_file(&Self::spill_filename(new_data_num))
1174            .await?;
1175
1176        Ok(final_num)
1177    }
1178
1179    pub async fn write_index(
1180        mut self,
1181        store: &dyn IndexStore,
1182        spill_files: Vec<usize>,
1183        old_index: Option<Arc<dyn IndexStore>>,
1184    ) -> Result<()> {
1185        let mut writer = store
1186            .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1187            .await?;
1188
1189        if spill_files.is_empty() {
1190            if let Some(old_index) = old_index {
1191                // An update with no new data, just copy the old index to the new store
1192                old_index.copy_index_file(POSTINGS_FILENAME, store).await?;
1193            } else {
1194                // Training an index with no data, make an empty index
1195                let mut writer = store
1196                    .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1197                    .await?;
1198                writer.finish().await?;
1199            }
1200            return Ok(());
1201        }
1202
1203        let mut index_to_copy = self.merge_spills(spill_files).await?;
1204
1205        if let Some(old_index) = old_index {
1206            index_to_copy = self.merge_old_index(index_to_copy, old_index).await?;
1207        }
1208
1209        let reader = self
1210            .spill_store
1211            .open_index_file(&Self::spill_filename(index_to_copy))
1212            .await?;
1213
1214        let num_rows = reader.num_rows();
1215        let mut offset = 0;
1216
1217        while offset < num_rows {
1218            let batch_size = std::cmp::min(num_rows - offset, 64);
1219            let batch = reader.read_range(offset..offset + batch_size, None).await?;
1220            writer.write_record_batch(batch).await?;
1221            offset += batch_size;
1222        }
1223
1224        writer.finish().await
1225    }
1226}
1227
1228#[derive(Debug, Default)]
1229pub struct NGramIndexPlugin;
1230
1231impl NGramIndexPlugin {
1232    pub async fn train_ngram_index(
1233        batches_source: SendableRecordBatchStream,
1234        index_store: &dyn IndexStore,
1235    ) -> Result<()> {
1236        let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
1237
1238        let spill_files = builder.train(batches_source).await?;
1239
1240        builder.write_index(index_store, spill_files, None).await
1241    }
1242}
1243
1244#[async_trait]
1245impl ScalarIndexPlugin for NGramIndexPlugin {
1246    fn name(&self) -> &str {
1247        "NGram"
1248    }
1249
1250    fn new_training_request(
1251        &self,
1252        _params: &str,
1253        field: &Field,
1254    ) -> Result<Box<dyn TrainingRequest>> {
1255        if !matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
1256            return Err(Error::invalid_input_source(format!(
1257                "A ngram index can only be created on a Utf8 or LargeUtf8 field.  Column has type {:?}",
1258                field.data_type()
1259            )
1260            .into()));
1261        }
1262        Ok(Box::new(DefaultTrainingRequest::new(
1263            TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
1264        )))
1265    }
1266
1267    fn provides_exact_answer(&self) -> bool {
1268        false
1269    }
1270
1271    fn version(&self) -> u32 {
1272        NGRAM_INDEX_VERSION
1273    }
1274
1275    fn new_query_parser(
1276        &self,
1277        index_name: String,
1278        _index_details: &prost_types::Any,
1279    ) -> Option<Box<dyn ScalarQueryParser>> {
1280        Some(Box::new(TextQueryParser::new(
1281            index_name,
1282            self.name().to_string(),
1283            true,
1284        )))
1285    }
1286
1287    async fn train_index(
1288        &self,
1289        data: SendableRecordBatchStream,
1290        index_store: &dyn IndexStore,
1291        _request: Box<dyn TrainingRequest>,
1292        fragment_ids: Option<Vec<u32>>,
1293        _progress: Arc<dyn crate::progress::IndexBuildProgress>,
1294    ) -> Result<CreatedIndex> {
1295        if fragment_ids.is_some() {
1296            return Err(Error::invalid_input_source(
1297                "NGram index does not support fragment training".into(),
1298            ));
1299        }
1300
1301        Self::train_ngram_index(data, index_store).await?;
1302        Ok(CreatedIndex {
1303            index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
1304                .unwrap(),
1305            index_version: NGRAM_INDEX_VERSION,
1306            files: Some(index_store.list_files_with_sizes().await?),
1307        })
1308    }
1309
1310    async fn load_index(
1311        &self,
1312        index_store: Arc<dyn IndexStore>,
1313        _index_details: &prost_types::Any,
1314        frag_reuse_index: Option<Arc<FragReuseIndex>>,
1315        cache: &LanceCache,
1316    ) -> Result<Arc<dyn ScalarIndex>> {
1317        Ok(NGramIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
1318    }
1319}
1320
1321#[cfg(test)]
1322mod tests {
1323    use std::{
1324        collections::{HashMap, HashSet},
1325        sync::Arc,
1326    };
1327
1328    use arrow::datatypes::UInt64Type;
1329    use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
1330    use arrow_schema::{DataType, Field, Schema};
1331    use datafusion::{
1332        execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
1333    };
1334    use datafusion_common::DataFusionError;
1335    use futures::{TryStreamExt, stream};
1336    use itertools::Itertools;
1337    use lance_core::{
1338        ROW_ID,
1339        cache::LanceCache,
1340        utils::{mask::RowAddrTreeMap, tempfile::TempDir},
1341    };
1342    use lance_datagen::{BatchCount, ByteCount, RowCount};
1343    use lance_io::object_store::ObjectStore;
1344    use lance_tokenizer::TextAnalyzer;
1345
1346    use crate::scalar::{
1347        ScalarIndex, SearchResult, TextQuery,
1348        lance_format::LanceIndexStore,
1349        ngram::{NGramIndex, NGramIndexBuilder, NGramIndexBuilderOptions},
1350    };
1351    use crate::{metrics::NoOpMetricsCollector, scalar::registry::VALUE_COLUMN_NAME};
1352
1353    use super::{NGRAM_TOKENIZER, ngram_to_token, tokenize_visitor};
1354
1355    fn collect_tokens(analyzer: &TextAnalyzer, text: &str) -> Vec<String> {
1356        let mut tokens = Vec::with_capacity(text.len() * 3);
1357        tokenize_visitor(analyzer, text, |token| tokens.push(token.to_owned()));
1358        tokens
1359    }
1360
1361    #[test]
1362    fn test_tokenizer() {
1363        let tokenizer = NGRAM_TOKENIZER.clone();
1364
1365        // ASCII folding
1366        let tokens = collect_tokens(&tokenizer, "café");
1367        assert_eq!(
1368            tokens,
1369            vec!["caf", "afe"] // spellchecker:disable-line
1370        );
1371
1372        // Allow numbers
1373        let tokens = collect_tokens(&tokenizer, "a1b2");
1374        assert_eq!(tokens, vec!["a1b", "1b2"]);
1375
1376        // Remove symbols and UTF-8 that doesn't map to characters
1377        let tokens = collect_tokens(&tokenizer, "abc👍b!c24");
1378
1379        assert_eq!(tokens, vec!["abc", "c24"]);
1380
1381        let tokens = collect_tokens(&tokenizer, "anstoß");
1382
1383        assert_eq!(tokens, vec!["ans", "nst", "sto", "tos", "oss"]);
1384
1385        // Lower casing
1386        let tokens = collect_tokens(&tokenizer, "ABC");
1387        assert_eq!(tokens, vec!["abc"]);
1388
1389        // Duplicate tokens
1390        let tokens = collect_tokens(&tokenizer, "ababab");
1391        // Confirming that the tokenizer doesn't deduplicate tokens (this can be taken into consideration
1392        // when training the index)
1393        assert_eq!(
1394            tokens,
1395            vec!["aba", "bab", "aba", "bab"] // spellchecker:disable-line
1396        );
1397    }
1398
1399    async fn do_train(
1400        mut builder: NGramIndexBuilder,
1401        data: SendableRecordBatchStream,
1402    ) -> (NGramIndex, Arc<TempDir>) {
1403        let spill_files = builder.train(data).await.unwrap();
1404
1405        let tmpdir = Arc::new(TempDir::default());
1406        let test_store = LanceIndexStore::new(
1407            Arc::new(ObjectStore::local()),
1408            tmpdir.obj_path(),
1409            Arc::new(LanceCache::no_cache()),
1410        );
1411
1412        builder
1413            .write_index(&test_store, spill_files, None)
1414            .await
1415            .unwrap();
1416
1417        (
1418            NGramIndex::from_store(Arc::new(test_store), None, &LanceCache::no_cache())
1419                .await
1420                .unwrap(),
1421            tmpdir,
1422        )
1423    }
1424
1425    async fn get_posting_list_for_trigram(index: &NGramIndex, trigram: &str) -> Vec<u64> {
1426        let token = ngram_to_token(trigram, 3);
1427        let row_offset = index.tokens[&token];
1428        let list = index
1429            .list_reader
1430            .ngram_list(row_offset, &NoOpMetricsCollector)
1431            .await
1432            .unwrap();
1433        list.bitmap.iter().sorted().collect()
1434    }
1435
1436    async fn get_null_posting_list(index: &NGramIndex) -> Vec<u64> {
1437        let row_offset = index.tokens[&0];
1438        let list = index
1439            .list_reader
1440            .ngram_list(row_offset, &NoOpMetricsCollector)
1441            .await
1442            .unwrap();
1443        list.bitmap.iter().sorted().collect()
1444    }
1445
1446    #[test_log::test(tokio::test)]
1447    async fn test_basic_ngram_index() {
1448        let data = StringArray::from_iter_values([
1449            "cat",
1450            "dog",
1451            "cat dog",
1452            "dog cat",
1453            "elephant",
1454            "mouse",
1455            "rhino",
1456            "giraffe",
1457            "rhinos nose",
1458        ]);
1459        let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1460        let schema = Arc::new(Schema::new(vec![
1461            Field::new(VALUE_COLUMN_NAME, DataType::Utf8, false),
1462            Field::new(ROW_ID, DataType::UInt64, false),
1463        ]));
1464        let data =
1465            RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1466        let data = Box::pin(RecordBatchStreamAdapter::new(
1467            schema,
1468            stream::once(std::future::ready(Ok(data))),
1469        ));
1470
1471        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1472
1473        let (index, _tmpdir) = do_train(builder, data).await;
1474        assert_eq!(index.tokens.len(), 21);
1475
1476        // Basic search
1477        let res = index
1478            .search(
1479                &TextQuery::StringContains("cat".to_string()),
1480                &NoOpMetricsCollector,
1481            )
1482            .await
1483            .unwrap();
1484
1485        let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 2, 3]));
1486
1487        assert_eq!(expected, res);
1488
1489        // Whitespace in query
1490        let res = index
1491            .search(
1492                &TextQuery::StringContains("nos nos".to_string()),
1493                &NoOpMetricsCollector,
1494            )
1495            .await
1496            .unwrap();
1497        let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1498        assert_eq!(expected, res);
1499
1500        // No matches
1501        let res = index
1502            .search(
1503                &TextQuery::StringContains("tdo".to_string()),
1504                &NoOpMetricsCollector,
1505            )
1506            .await
1507            .unwrap();
1508        let expected = SearchResult::exact(RowAddrTreeMap::new());
1509        assert_eq!(expected, res);
1510
1511        // False positive
1512        let res = index
1513            .search(
1514                &TextQuery::StringContains("inose".to_string()),
1515                &NoOpMetricsCollector,
1516            )
1517            .await
1518            .unwrap();
1519        let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1520        assert_eq!(expected, res);
1521
1522        // Too short, don't know anything
1523        let res = index
1524            .search(
1525                &TextQuery::StringContains("ab".to_string()),
1526                &NoOpMetricsCollector,
1527            )
1528            .await
1529            .unwrap();
1530        let expected = SearchResult::at_least(RowAddrTreeMap::new());
1531        assert_eq!(expected, res);
1532
1533        // One short string but we still get at least one trigram, this is ok
1534        let res = index
1535            .search(
1536                &TextQuery::StringContains("no nos".to_string()),
1537                &NoOpMetricsCollector,
1538            )
1539            .await
1540            .unwrap();
1541        let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1542        assert_eq!(expected, res);
1543    }
1544
1545    fn test_data_schema() -> Arc<Schema> {
1546        Arc::new(Schema::new(vec![
1547            Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1548            Field::new(ROW_ID, DataType::UInt64, false),
1549        ]))
1550    }
1551
1552    fn simple_data_with_nulls() -> SendableRecordBatchStream {
1553        let data = StringArray::from_iter(&[Some("cat"), Some("dog"), None, None, Some("cat dog")]);
1554        let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1555        let schema = test_data_schema();
1556        let data =
1557            RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1558        Box::pin(RecordBatchStreamAdapter::new(
1559            schema,
1560            stream::once(std::future::ready(Ok(data))),
1561        ))
1562    }
1563
1564    #[test_log::test(tokio::test)]
1565    async fn test_ngram_nulls() {
1566        let data = simple_data_with_nulls();
1567
1568        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1569
1570        let (index, _tmpdir) = do_train(builder, data).await;
1571        assert_eq!(index.tokens.len(), 3);
1572
1573        let res = index
1574            .search(
1575                &TextQuery::StringContains("cat".to_string()),
1576                &NoOpMetricsCollector,
1577            )
1578            .await
1579            .unwrap();
1580        let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 4]));
1581        assert_eq!(expected, res);
1582
1583        let null_posting_list = get_null_posting_list(&index).await;
1584        assert_eq!(null_posting_list, vec![2, 3]);
1585
1586        // TODO: Support IS NULL queries
1587    }
1588
1589    fn empty_data() -> SendableRecordBatchStream {
1590        Box::pin(RecordBatchStreamAdapter::new(
1591            test_data_schema(),
1592            stream::empty::<lance_core::error::DataFusionResult<RecordBatch>>(),
1593        ))
1594    }
1595
1596    #[test_log::test(tokio::test)]
1597    async fn test_train_empty() {
1598        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1599
1600        let (index, _tmpdir) = do_train(builder, empty_data()).await;
1601        assert_eq!(index.tokens.len(), 0);
1602    }
1603
1604    #[test_log::test(tokio::test)]
1605    async fn test_update_empty() {
1606        let data = simple_data_with_nulls();
1607
1608        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1609        let (index, _tmpdir) = do_train(builder, empty_data()).await;
1610
1611        let new_tmpdir = Arc::new(TempDir::default());
1612        let test_store = Arc::new(LanceIndexStore::new(
1613            Arc::new(ObjectStore::local()),
1614            new_tmpdir.obj_path(),
1615            Arc::new(LanceCache::no_cache()),
1616        ));
1617
1618        index.update(data, test_store.as_ref(), None).await.unwrap();
1619
1620        let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1621            .await
1622            .unwrap();
1623        assert_eq!(index.tokens.len(), 3);
1624    }
1625
1626    async fn row_ids_in_index(index: &NGramIndex) -> Vec<u64> {
1627        let mut row_ids = HashSet::new();
1628        for row_offset in index.tokens.values() {
1629            let list = index
1630                .list_reader
1631                .ngram_list(*row_offset, &NoOpMetricsCollector)
1632                .await
1633                .unwrap();
1634            row_ids.extend(list.bitmap.iter());
1635        }
1636        row_ids.into_iter().sorted().collect()
1637    }
1638
1639    #[test_log::test(tokio::test)]
1640    async fn test_ngram_index_remap() {
1641        let data = simple_data_with_nulls();
1642        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1643        let (index, _tmpdir) = do_train(builder, data).await;
1644
1645        let row_ids = row_ids_in_index(&index).await;
1646        assert_eq!(row_ids, vec![0, 1, 2, 3, 4]);
1647
1648        let new_tmpdir = Arc::new(TempDir::default());
1649        let test_store = Arc::new(LanceIndexStore::new(
1650            Arc::new(ObjectStore::local()),
1651            new_tmpdir.obj_path(),
1652            Arc::new(LanceCache::no_cache()),
1653        ));
1654
1655        let remapping = HashMap::from([(2, Some(100)), (3, None), (4, Some(101))]);
1656        index.remap(&remapping, test_store.as_ref()).await.unwrap();
1657
1658        let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1659            .await
1660            .unwrap();
1661        let row_ids = row_ids_in_index(&index).await;
1662        assert_eq!(row_ids, vec![0, 1, 100, 101]);
1663
1664        let null_posting_list = get_null_posting_list(&index).await;
1665        assert_eq!(null_posting_list, vec![100]);
1666    }
1667
1668    #[test_log::test(tokio::test)]
1669    async fn test_ngram_index_merge() {
1670        let data = simple_data_with_nulls();
1671        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1672        let (index, _tmpdir) = do_train(builder, data).await;
1673
1674        let data = StringArray::from_iter(&[Some("giraffe"), Some("cat"), None]);
1675        let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64 + 100));
1676        let schema = Arc::new(Schema::new(vec![
1677            Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1678            Field::new(ROW_ID, DataType::UInt64, false),
1679        ]));
1680        let data =
1681            RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1682        let data = Box::pin(RecordBatchStreamAdapter::new(
1683            schema,
1684            stream::once(std::future::ready(Ok(data))),
1685        ));
1686
1687        let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1688        assert_eq!(posting_list, vec![0, 4]);
1689
1690        let new_tmpdir = Arc::new(TempDir::default());
1691        let test_store = Arc::new(LanceIndexStore::new(
1692            Arc::new(ObjectStore::local()),
1693            new_tmpdir.obj_path(),
1694            Arc::new(LanceCache::no_cache()),
1695        ));
1696
1697        index.update(data, test_store.as_ref(), None).await.unwrap();
1698
1699        let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1700            .await
1701            .unwrap();
1702        let row_ids = row_ids_in_index(&index).await;
1703        assert_eq!(row_ids, vec![0, 1, 2, 3, 4, 100, 101, 102]);
1704
1705        let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1706        assert_eq!(posting_list, vec![0, 4, 101]);
1707
1708        let posting_list = get_posting_list_for_trigram(&index, "ffe").await;
1709        assert_eq!(posting_list, vec![100]);
1710
1711        let posting_list = get_null_posting_list(&index).await;
1712        assert_eq!(posting_list, vec![2, 3, 102]);
1713    }
1714
1715    #[test_log::test(tokio::test)]
1716    async fn test_ngram_index_with_spill() {
1717        let (data, schema) = lance_datagen::gen_batch()
1718            .col(
1719                VALUE_COLUMN_NAME,
1720                lance_datagen::array::rand_utf8(ByteCount::from(50), false),
1721            )
1722            .col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
1723            .into_reader_stream(RowCount::from(128), BatchCount::from(32));
1724
1725        let data = Box::pin(RecordBatchStreamAdapter::new(
1726            schema,
1727            data.map_err(|arrow_err| DataFusionError::ArrowError(Box::new(arrow_err), None)),
1728        ));
1729
1730        let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions {
1731            tokens_per_spill: 100,
1732        })
1733        .unwrap();
1734
1735        let (index, _tmpdir) = do_train(builder, data).await;
1736
1737        assert_eq!(index.tokens.len(), 29012);
1738    }
1739}