1use std::any::Any;
5use std::collections::BTreeMap;
6use std::iter::once;
7use std::time::Instant;
8use std::{collections::HashMap, sync::Arc};
9
10use super::lance_format::LanceIndexStore;
11use super::{
12 AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector,
13 ScalarIndex, ScalarIndexParams, SearchResult, TextQuery,
14};
15use crate::frag_reuse::FragReuseIndex;
16use crate::metrics::NoOpMetricsCollector;
17use crate::pbold;
18use crate::scalar::expression::{ScalarQueryParser, TextQueryParser};
19use crate::scalar::registry::{
20 DefaultTrainingRequest, ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest,
21 VALUE_COLUMN_NAME,
22};
23use crate::scalar::{CreatedIndex, UpdateCriteria};
24use crate::vector::VectorIndex;
25use crate::{Index, IndexType};
26use arrow::array::{AsArray, UInt32Builder};
27use arrow::datatypes::{UInt32Type, UInt64Type};
28use arrow_array::{BinaryArray, RecordBatch, UInt32Array};
29use arrow_schema::{DataType, Field, Schema, SchemaRef};
30use async_trait::async_trait;
31use datafusion::execution::SendableRecordBatchStream;
32use deepsize::DeepSizeOf;
33use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream};
34use lance_arrow::iter_str_array;
35use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
36use lance_core::error::LanceOptionExt;
37use lance_core::utils::address::RowAddress;
38use lance_core::utils::tempfile::TempDir;
39use lance_core::utils::tokio::get_num_compute_intensive_cpus;
40use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
41use lance_core::{Error, utils::mask::RowAddrTreeMap};
42use lance_core::{ROW_ID, Result};
43use lance_io::object_store::ObjectStore;
44use lance_tokenizer::{
45 AlphaNumOnlyFilter, AsciiFoldingFilter, LowerCaser, NgramTokenizer, RawTokenizer, TextAnalyzer,
46};
47use log::info;
48use roaring::{RoaringBitmap, RoaringTreemap};
49use serde::Serialize;
50use tracing::instrument;
51
/// Name of the postings-file column that holds the integer-encoded ngram tokens.
const TOKENS_COL: &str = "tokens";
/// Name of the postings-file column that holds the serialized posting list of each token.
const POSTING_LIST_COL: &str = "posting_list";
/// File name of the postings file inside an index store.
const POSTINGS_FILENAME: &str = "ngram_postings.lance";
/// Version reported in the `CreatedIndex` returned by `remap` and `update`.
const NGRAM_INDEX_VERSION: u32 = 0;
56
57use std::sync::LazyLock;
58
/// Arrow field for the token column: nullable `UInt32`.
pub static TOKENS_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(TOKENS_COL, DataType::UInt32, true));
/// Arrow field for the posting-list column: non-nullable `Binary`.
pub static POSTINGS_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(POSTING_LIST_COL, DataType::Binary, false));
/// Schema of the postings file and of every spill file: token column, then posting-list column.
pub static POSTINGS_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
    Arc::new(Schema::new(vec![
        TOKENS_FIELD.clone(),
        POSTINGS_FIELD.clone(),
    ]))
});
/// Analyzer applied to raw text before ngram extraction: the whole input is one token
/// (`RawTokenizer`), which is lower-cased and ASCII-folded.
pub static TEXT_PREPPER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
    TextAnalyzer::builder(RawTokenizer::default())
        .filter(LowerCaser)
        .filter(AsciiFoldingFilter)
        .build()
});
/// Analyzer that emits ngrams of length 3 and passes them through `AlphaNumOnlyFilter`.
pub static NGRAM_TOKENIZER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
    TextAnalyzer::builder(NgramTokenizer::all_ngrams(3, 3).unwrap())
        .filter(AlphaNumOnlyFilter)
        .build()
});
81
82fn tokenize_visitor(tokenizer: &TextAnalyzer, text: &str, mut visitor: impl FnMut(&String)) {
84 let mut prepper = TEXT_PREPPER.clone();
92 let mut tokenizer = tokenizer.clone();
93 let mut raw_stream = prepper.token_stream(text);
94 while raw_stream.advance() {
95 let mut token_stream = tokenizer.token_stream(&raw_stream.token().text);
96 while token_stream.advance() {
97 visitor(&token_stream.token().text);
98 }
99 }
100}
101
/// Number of distinct digit values per ngram position: 36 characters (`0-9`, `a-z`)
/// shifted up by one, so the value 0 is never produced by a character.
const ALPHA_SPAN: usize = 37;
/// Upper end of the token range used to spread tokens across builder partitions.
const MAX_TOKEN: usize = ALPHA_SPAN.pow(2) + ALPHA_SPAN;
/// Lower end of the token range used to spread tokens across builder partitions.
const MIN_TOKEN: usize = 0;
/// Length of the ngrams stored in the index.
const NGRAM_N: usize = 3;

/// Encodes an ngram of lower-case ASCII letters and digits as a base-`ALPHA_SPAN` number.
///
/// Each byte becomes a digit in `1..=36` (`'0'..='9'` map to 1..=10, `'a'..='z'` to 11..=36)
/// and the byte at index `idx` is weighted by `ALPHA_SPAN^(ngram_length - idx - 1)`.
///
/// # Panics
///
/// Hits `unreachable!()` for any byte greater than `b'z'`.
fn ngram_to_token(ngram: &str, ngram_length: usize) -> u32 {
    ngram.bytes().enumerate().fold(0u32, |token, (idx, byte)| {
        let digit = match byte {
            b if b <= b'9' => b - b'0',
            b if b <= b'z' => b - b'a' + 10,
            _ => unreachable!(),
        } + 1;
        debug_assert!(digit < ALPHA_SPAN as u8);
        let weight = ALPHA_SPAN.pow(ngram_length as u32 - idx as u32 - 1) as u32;
        token + digit as u32 * weight
    })
}
144
/// Statistics reported by `NGramIndex::statistics`, serialized to JSON.
#[derive(Serialize)]
struct NGramStatistics {
    /// Number of distinct ngram tokens in the index.
    num_ngrams: usize,
}
150
/// The set of row ids recorded for a single ngram token.
#[derive(Debug)]
pub struct NGramPostingList {
    bitmap: RoaringTreemap,
}

impl DeepSizeOf for NGramPostingList {
    // Reports the bitmap's serialized size as the size of this value's children.
    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
        self.bitmap.serialized_size()
    }
}
162
163#[derive(Debug, Clone)]
165pub struct NGramPostingListKey {
166 pub row_offset: u32,
167}
168
169impl CacheKey for NGramPostingListKey {
170 type ValueType = NGramPostingList;
171
172 fn key(&self) -> std::borrow::Cow<'_, str> {
173 format!("posting-list-{}", self.row_offset).into()
174 }
175
176 fn type_name() -> &'static str {
177 "NGramPostingList"
178 }
179}
180
181impl NGramPostingList {
182 fn try_from_batch(
183 batch: RecordBatch,
184 frag_reuse_index: Option<Arc<FragReuseIndex>>,
185 ) -> Result<Self> {
186 let bitmap_bytes = batch.column(0).as_binary::<i32>().value(0);
187 let mut bitmap = RoaringTreemap::deserialize_from(bitmap_bytes)
188 .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))?;
189 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
190 bitmap = frag_reuse_index_ref.remap_row_ids_roaring_tree_map(&bitmap);
191 }
192 Ok(Self { bitmap })
193 }
194
195 fn intersect<'a>(lists: impl IntoIterator<Item = &'a Self>) -> RoaringTreemap {
196 let mut iter = lists.into_iter();
197 let mut result = iter
198 .next()
199 .map(|list| list.bitmap.clone())
200 .unwrap_or_default();
201 for list in iter {
202 result &= &list.bitmap;
203 }
204 result
205 }
206}
207
/// Reads individual posting lists from the postings file, going through a cache.
struct NGramPostingListReader {
    /// Reader over the postings file.
    reader: Arc<dyn IndexReader>,
    /// Optional index used to remap row ids of loaded posting lists.
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
    /// Cache that holds loaded posting lists keyed by row offset.
    index_cache: WeakLanceCache,
}

impl DeepSizeOf for NGramPostingListReader {
    // The reader reports no owned children.
    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
        0
    }
}

impl std::fmt::Debug for NGramPostingListReader {
    // Prints only the struct name "NGramListReader"; no fields are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NGramListReader").finish()
    }
}
226
227impl NGramPostingListReader {
228 #[instrument(level = "debug", skip(self, metrics))]
229 pub async fn ngram_list(
230 &self,
231 row_offset: u32,
232 metrics: &dyn MetricsCollector,
233 ) -> Result<Arc<NGramPostingList>> {
234 self.index_cache.get_or_insert_with_key(NGramPostingListKey { row_offset }, || async move {
235 metrics.record_part_load();
236 tracing::info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="ngram", part_id=row_offset);
237 let batch = self
238 .reader
239 .read_range(
240 row_offset as usize..row_offset as usize + 1,
241 Some(&[POSTING_LIST_COL]),
242 )
243 .await?;
244 NGramPostingList::try_from_batch(batch, self.frag_reuse_index.clone())
245 }).await
246 }
247}
248
/// A scalar index over ngrams that answers `TextQuery::StringContains` queries.
pub struct NGramIndex {
    /// Maps each ngram token to the row offset of its posting list in the postings file.
    tokens: HashMap<u32, u32>,
    /// Loads (and caches) posting lists by row offset.
    list_reader: Arc<NGramPostingListReader>,
    /// Analyzer used to split query strings into ngrams.
    tokenizer: TextAnalyzer,
    /// Number of posting lists fetched concurrently during a search.
    io_parallelism: usize,
    /// Store this index was loaded from; re-read by `remap` and `update`.
    store: Arc<dyn IndexStore>,
}

impl std::fmt::Debug for NGramIndex {
    // Only `tokens` and `list_reader` are included in the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NGramIndex")
            .field("tokens", &self.tokens)
            .field("list_reader", &self.list_reader)
            .finish()
    }
}

impl DeepSizeOf for NGramIndex {
    // Only the token map contributes to the reported size.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.tokens.deep_size_of_children(context)
    }
}
292
293impl NGramIndex {
294 async fn from_store(
295 store: Arc<dyn IndexStore>,
296 frag_reuse_index: Option<Arc<FragReuseIndex>>,
297 index_cache: &LanceCache,
298 ) -> Result<Self> {
299 let tokens = store.open_index_file(POSTINGS_FILENAME).await?;
300 let tokens = tokens
301 .read_range(0..tokens.num_rows(), Some(&[TOKENS_COL]))
302 .await?;
303
304 let tokens_map = HashMap::from_iter(
305 tokens
306 .column(0)
307 .as_primitive::<UInt32Type>()
308 .values()
309 .iter()
310 .copied()
311 .enumerate()
312 .map(|(idx, token)| (token, idx as u32)),
313 );
314
315 let posting_reader = Arc::new(NGramPostingListReader {
316 reader: store.open_index_file(POSTINGS_FILENAME).await?,
317 frag_reuse_index,
318 index_cache: WeakLanceCache::from(index_cache),
319 });
320
321 Ok(Self {
322 io_parallelism: store.io_parallelism(),
323 tokens: tokens_map,
324 list_reader: posting_reader,
325 tokenizer: NGRAM_TOKENIZER.clone(),
326 store,
327 })
328 }
329
330 fn remap_batch(
331 &self,
332 batch: RecordBatch,
333 mapping: &HashMap<u64, Option<u64>>,
334 ) -> Result<RecordBatch> {
335 let posting_lists_array = batch
336 .column_by_name(POSTING_LIST_COL)
337 .expect_ok()?
338 .as_binary::<i32>();
339
340 let new_posting_lists = posting_lists_array
341 .iter()
342 .map(|posting_list| {
343 let posting_list = posting_list.unwrap();
344 let posting_list = RoaringTreemap::deserialize_from(posting_list)?;
345 let new_posting_list =
346 RoaringTreemap::from_iter(posting_list.into_iter().filter_map(|row_id| {
347 match mapping.get(&row_id) {
348 Some(Some(new_row_id)) => Some(*new_row_id),
349 Some(None) => None,
350 None => Some(row_id),
351 }
352 }));
353 let mut buf = Vec::with_capacity(new_posting_list.serialized_size());
354 new_posting_list.serialize_into(&mut buf)?;
355 Ok(buf)
356 })
357 .collect::<Result<Vec<_>>>()?;
358
359 let new_posting_lists_array = BinaryArray::from_iter_values(new_posting_lists);
360
361 Ok(RecordBatch::try_new(
362 POSTINGS_SCHEMA.clone(),
363 vec![
364 batch.column_by_name(TOKENS_COL).expect_ok()?.clone(),
365 Arc::new(new_posting_lists_array),
366 ],
367 )?)
368 }
369
370 async fn load(
371 store: Arc<dyn IndexStore>,
372 frag_reuse_index: Option<Arc<FragReuseIndex>>,
373 index_cache: &LanceCache,
374 ) -> Result<Arc<Self>>
375 where
376 Self: Sized,
377 {
378 Ok(Arc::new(
379 Self::from_store(store, frag_reuse_index, index_cache).await?,
380 ))
381 }
382}
383
384#[async_trait]
385impl Index for NGramIndex {
386 fn as_any(&self) -> &dyn Any {
387 self
388 }
389
390 fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
391 self
392 }
393
394 fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
395 Err(Error::invalid_input_source(
396 "NGramIndex is not a vector index".into(),
397 ))
398 }
399
400 fn statistics(&self) -> Result<serde_json::Value> {
401 let ngram_stats = NGramStatistics {
402 num_ngrams: self.tokens.len(),
403 };
404 serde_json::to_value(ngram_stats)
405 .map_err(|e| Error::internal(format!("Error serializing statistics: {}", e)))
406 }
407
408 async fn prewarm(&self) -> Result<()> {
409 Ok(())
411 }
412
413 fn index_type(&self) -> IndexType {
414 IndexType::NGram
415 }
416
417 async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
418 let mut frag_ids = RoaringBitmap::new();
419 for row_offset in self.tokens.values() {
420 let list = self
421 .list_reader
422 .ngram_list(*row_offset, &NoOpMetricsCollector)
423 .await?;
424 frag_ids.extend(
425 list.bitmap
426 .iter()
427 .map(|row_addr| RowAddress::from(row_addr).fragment_id()),
428 );
429 }
430 Ok(frag_ids)
431 }
432}
433
434#[async_trait]
435impl ScalarIndex for NGramIndex {
436 async fn search(
437 &self,
438 query: &dyn AnyQuery,
439 metrics: &dyn MetricsCollector,
440 ) -> Result<SearchResult> {
441 let query = query
442 .as_any()
443 .downcast_ref::<TextQuery>()
444 .ok_or_else(|| Error::invalid_input_source("Query is not a TextQuery".into()))?;
445 match query {
446 TextQuery::StringContains(substr) => {
447 if substr.len() < NGRAM_N {
448 return Ok(SearchResult::at_least(RowAddrTreeMap::new()));
450 }
451
452 let mut row_offsets = Vec::with_capacity(substr.len() * 3);
453 let mut missing = false;
454 tokenize_visitor(&self.tokenizer, substr, |ngram| {
455 let token = ngram_to_token(ngram, NGRAM_N);
456 if let Some(row_offset) = self.tokens.get(&token) {
457 row_offsets.push(*row_offset);
458 } else {
459 missing = true;
460 }
461 });
462 if missing {
464 return Ok(SearchResult::exact(RowAddrTreeMap::new()));
465 }
466 let posting_lists = futures::stream::iter(
467 row_offsets
468 .into_iter()
469 .map(|row_offset| self.list_reader.ngram_list(row_offset, metrics)),
470 )
471 .buffer_unordered(self.io_parallelism)
472 .try_collect::<Vec<_>>()
473 .await?;
474 metrics.record_comparisons(posting_lists.len());
475 let list_refs = posting_lists.iter().map(|list| list.as_ref());
476 let row_ids = NGramPostingList::intersect(list_refs);
477 Ok(SearchResult::at_most(RowAddrTreeMap::from(row_ids)))
478 }
479 }
480 }
481
/// Always `true`; see `remap` in this impl.
fn can_remap(&self) -> bool {
    true
}
485
486 async fn remap(
487 &self,
488 mapping: &HashMap<u64, Option<u64>>,
489 dest_store: &dyn IndexStore,
490 ) -> Result<CreatedIndex> {
491 let reader = self.store.open_index_file(POSTINGS_FILENAME).await?;
492 let mut writer = dest_store
493 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
494 .await?;
495
496 let mut offset = 0;
497 let num_rows = reader.num_rows();
498 const BATCH_SIZE: usize = 64;
499 while offset < num_rows {
500 let batch_size = BATCH_SIZE.min(num_rows - offset);
501 let batch = reader.read_range(offset..offset + batch_size, None).await?;
502 let batch = self.remap_batch(batch, mapping)?;
503 writer.write_record_batch(batch).await?;
504 offset += BATCH_SIZE;
505 }
506
507 writer.finish().await?;
508
509 Ok(CreatedIndex {
510 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
511 .unwrap(),
512 index_version: NGRAM_INDEX_VERSION,
513 files: Some(dest_store.list_files_with_sizes().await?),
514 })
515 }
516
517 async fn update(
518 &self,
519 new_data: SendableRecordBatchStream,
520 dest_store: &dyn IndexStore,
521 _old_data_filter: Option<super::OldIndexDataFilter>,
522 ) -> Result<CreatedIndex> {
523 let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
524 let spill_files = builder.train(new_data).await?;
525
526 builder
527 .write_index(dest_store, spill_files, Some(self.store.clone()))
528 .await?;
529
530 Ok(CreatedIndex {
531 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
532 .unwrap(),
533 index_version: NGRAM_INDEX_VERSION,
534 files: Some(dest_store.list_files_with_sizes().await?),
535 })
536 }
537
/// Updates need only the new data, in no particular order, with row ids attached.
fn update_criteria(&self) -> UpdateCriteria {
    UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::None).with_row_id())
}
541
/// Returns the parameters of the built-in NGram index type.
fn derive_index_params(&self) -> Result<ScalarIndexParams> {
    Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::NGram))
}
545}
546
/// Tuning options for `NGramIndexBuilder`.
#[derive(Debug, Clone)]
pub struct NGramIndexBuilderOptions {
    /// Number of tokens a worker accumulates before it flushes its state to a spill file.
    tokens_per_spill: usize,
}
551
/// Default for `NGramIndexBuilderOptions::tokens_per_spill`.
///
/// Taken from the `LANCE_NGRAM_TOKENS_PER_SPILL` environment variable, falling back to
/// 1,000,000,000 when the variable cannot be read. Panics on first access if the
/// variable is set to something that does not parse as `usize`.
static DEFAULT_TOKENS_PER_SPILL: LazyLock<usize> = LazyLock::new(|| {
    let configured = match std::env::var("LANCE_NGRAM_TOKENS_PER_SPILL") {
        Ok(value) => value,
        Err(_) => "1000000000".to_string(),
    };
    configured
        .parse()
        .expect("failed to parse LANCE_NGRAM_TOKENS_PER_SPILL")
});
559static DEFAULT_NUM_PARTITIONS: LazyLock<usize> = LazyLock::new(|| {
565 std::env::var("LANCE_NGRAM_NUM_PARTITIONS")
566 .map(|s| s.parse().expect("failed to parse LANCE_NGRAM_PARALLELISM"))
567 .unwrap_or((get_num_compute_intensive_cpus() * 4).max(128))
568});
/// Number of batches tokenized concurrently by `NGramIndexBuilder::train`.
///
/// Taken from the `LANCE_NGRAM_TOKENIZE_PARALLELISM` environment variable, falling back
/// to 8 when the variable cannot be read. Panics on first access if the variable is set
/// to something that does not parse as `usize`.
static DEFAULT_TOKENIZE_PARALLELISM: LazyLock<usize> = LazyLock::new(|| {
    match std::env::var("LANCE_NGRAM_TOKENIZE_PARALLELISM") {
        Ok(value) => value
            .parse()
            .expect("failed to parse LANCE_NGRAM_TOKENIZE_PARALLELISM"),
        Err(_) => 8,
    }
});
578
impl Default for NGramIndexBuilderOptions {
    /// Uses `DEFAULT_TOKENS_PER_SPILL` for `tokens_per_spill`.
    fn default() -> Self {
        Self {
            tokens_per_spill: *DEFAULT_TOKENS_PER_SPILL,
        }
    }
}
586
587struct NGramIndexSpillState {
591 tokens: UInt32Array,
592 bitmaps: Vec<RoaringTreemap>,
593}
594
595impl NGramIndexSpillState {
596 fn try_from_batch(batch: RecordBatch) -> Result<Self> {
597 let tokens = batch
598 .column_by_name(TOKENS_COL)
599 .expect_ok()?
600 .as_primitive::<UInt32Type>()
601 .clone();
602 let postings = batch
603 .column_by_name(POSTING_LIST_COL)
604 .expect_ok()?
605 .as_binary::<i32>();
606
607 let bitmaps = postings
608 .into_iter()
609 .map(|bytes| {
610 RoaringTreemap::deserialize_from(bytes.expect_ok()?)
611 .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))
612 })
613 .collect::<Result<Vec<_>>>()?;
614
615 Ok(Self { tokens, bitmaps })
616 }
617
618 fn try_into_batch(self) -> Result<RecordBatch> {
619 let bitmap_array = BinaryArray::from_iter_values(self.bitmaps.into_iter().map(|bitmap| {
620 let mut buf = Vec::with_capacity(bitmap.serialized_size());
621 bitmap.serialize_into(&mut buf).unwrap();
622 buf
623 }));
624 Ok(RecordBatch::try_new(
625 POSTINGS_SCHEMA.clone(),
626 vec![Arc::new(self.tokens), Arc::new(bitmap_array)],
627 )?)
628 }
629}
630
631struct NGramIndexBuildState {
634 tokens_map: BTreeMap<u32, RoaringTreemap>,
635}
636
637impl NGramIndexBuildState {
638 fn starting() -> Self {
639 Self {
640 tokens_map: BTreeMap::new(),
641 }
642 }
643
644 fn take(&mut self) -> Self {
645 let mut taken = Self::starting();
646 std::mem::swap(&mut self.tokens_map, &mut taken.tokens_map);
647 taken
648 }
649
650 fn into_spill(self) -> NGramIndexSpillState {
651 let tokens = UInt32Array::from_iter_values(self.tokens_map.keys().copied());
653 let bitmaps = Vec::from_iter(self.tokens_map.into_values());
654
655 NGramIndexSpillState { bitmaps, tokens }
656 }
657}
658
/// Builds an ngram index from a stream of (text, row id) batches, spilling partial
/// results to temporary files.
pub struct NGramIndexBuilder {
    /// Analyzer used to split text into ngrams.
    tokenizer: TextAnalyzer,
    options: NGramIndexBuilderOptions,
    /// Temporary directory that backs `spill_store`; shared with worker clones.
    tmpdir: Arc<TempDir>,
    /// Store that spill files are written to and read from.
    spill_store: Arc<dyn IndexStore>,

    /// Tokens processed since the last flush.
    tokens_seen: usize,
    /// 0 for a builder made by `from_state`; worker clones get the number passed to
    /// `clone_worker`. Also used as the id of this builder's spill file.
    worker_number: usize,
    /// Whether this builder has already written a spill file.
    has_flushed: bool,

    /// Postings accumulated in memory and not yet flushed.
    state: NGramIndexBuildState,
}
683
684impl NGramIndexBuilder {
/// Creates a builder with an empty in-memory state and the given `options`.
pub fn try_new(options: NGramIndexBuilderOptions) -> Result<Self> {
    Self::from_state(NGramIndexBuildState::starting(), options)
}
688
689 fn clone_worker(&self, worker_number: usize) -> Self {
690 let mut bitmaps = Vec::with_capacity(36 * 36 * 36 + 1);
691 bitmaps.push(RoaringTreemap::new());
693 Self {
694 tokenizer: self.tokenizer.clone(),
695 state: NGramIndexBuildState::starting(),
696 tmpdir: self.tmpdir.clone(),
697 spill_store: self.spill_store.clone(),
698 options: self.options.clone(),
699 tokens_seen: 0,
700 worker_number,
701 has_flushed: false,
702 }
703 }
704
705 fn from_state(state: NGramIndexBuildState, options: NGramIndexBuilderOptions) -> Result<Self> {
706 let tokenizer = NGRAM_TOKENIZER.clone();
707
708 let tmpdir = Arc::new(TempDir::default());
709 let spill_store = Arc::new(LanceIndexStore::new(
710 Arc::new(ObjectStore::local()),
711 tmpdir.obj_path(),
712 Arc::new(LanceCache::no_cache()),
713 ));
714
715 Ok(Self {
716 tokenizer,
717 state,
718 tmpdir,
719 spill_store,
720 options,
721 tokens_seen: 0,
722 worker_number: 0,
723 has_flushed: false,
724 })
725 }
726
727 fn validate_schema(schema: &Schema) -> Result<()> {
728 if schema.fields().len() != 2 {
729 return Err(Error::invalid_input_source(
730 "Ngram index schema must have exactly two fields".into(),
731 ));
732 }
733 let values_field = schema.field_with_name(VALUE_COLUMN_NAME)?;
734 if *values_field.data_type() != DataType::Utf8
735 && *values_field.data_type() != DataType::LargeUtf8
736 {
737 return Err(Error::invalid_input_source(
738 "First field in ngram index schema must be of type Utf8/LargeUtf8".into(),
739 ));
740 }
741 let row_id_field = schema.field_with_name(ROW_ID)?;
742 if *row_id_field.data_type() != DataType::UInt64 {
743 return Err(Error::invalid_input_source(
744 "Second field in ngram index schema must be of type UInt64".into(),
745 ));
746 }
747 Ok(())
748 }
749
750 async fn process_batch(&mut self, tokens_and_ids: Vec<(u32, u64)>) -> Result<()> {
751 let mut tokens_seen = 0;
752 for (token, row_id) in tokens_and_ids {
753 tokens_seen += 1;
754 self.state
759 .tokens_map
760 .entry(token)
761 .or_default()
762 .insert(row_id);
763 }
764 self.tokens_seen += tokens_seen;
765 if self.tokens_seen >= self.options.tokens_per_spill {
766 let state = self.state.take();
767 self.flush(state).await?;
768 }
769 Ok(())
770 }
771
/// Name of the spill file with the given id: `spill-<id>.lance`.
fn spill_filename(id: usize) -> String {
    format!("spill-{id}.lance")
}
775
/// Name of the temporary file written while re-merging spill `id`: `spill-<id>.lance.tmp`.
fn tmp_spill_filename(id: usize) -> String {
    format!("spill-{id}.lance.tmp")
}
779
780 async fn flush(&mut self, state: NGramIndexBuildState) -> Result<bool> {
781 if self.tokens_seen == 0 {
782 assert!(state.tokens_map.is_empty());
783 return Ok(self.has_flushed);
784 }
785 self.tokens_seen = 0;
786 let spill_state = state.into_spill();
787 let flush_start = Instant::now();
788 debug_assert_ne!(self.worker_number, 0);
790 if self.has_flushed {
791 info!("Merging flush for worker {}", self.worker_number);
792 let mut writer = self
794 .spill_store
795 .new_index_file(
796 &Self::tmp_spill_filename(self.worker_number),
797 POSTINGS_SCHEMA.clone(),
798 )
799 .await?;
800
801 let left_stream = stream::once(std::future::ready(Ok(spill_state)));
802 let right_stream =
803 Self::stream_spill(self.spill_store.clone(), self.worker_number).await?;
804 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
805 drop(writer);
806 self.spill_store
807 .rename_index_file(
808 &Self::tmp_spill_filename(self.worker_number),
809 &Self::spill_filename(self.worker_number),
810 )
811 .await?;
812 } else {
813 info!("Initial flush for worker {}", self.worker_number);
815 self.has_flushed = true;
816 let writer = self
817 .spill_store
818 .new_index_file(
819 &Self::spill_filename(self.worker_number),
820 POSTINGS_SCHEMA.clone(),
821 )
822 .await?;
823 self.write(writer, spill_state).await?;
824 }
825 let flush_time = flush_start.elapsed();
826 info!(
827 "Flushed worker {} in {}ms",
828 self.worker_number,
829 flush_time.as_millis()
830 );
831 Ok(true)
832 }
833
834 fn tokenize_and_partition(
835 tokenizer: &TextAnalyzer,
836 batch: RecordBatch,
837 num_workers: usize,
838 ) -> Result<Vec<Vec<(u32, u64)>>> {
839 let text_iter = iter_str_array(batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?);
840 let row_id_col = batch
841 .column_by_name(ROW_ID)
842 .expect_ok()?
843 .as_primitive::<UInt64Type>();
844 let mut partitions = vec![Vec::with_capacity(batch.num_rows() * 1000); num_workers];
846 let divisor = (MAX_TOKEN - MIN_TOKEN) / num_workers;
847 for (text, row_id) in text_iter.zip(row_id_col.values()) {
848 if let Some(text) = text {
849 tokenize_visitor(tokenizer, text, |token| {
850 let token = ngram_to_token(token, NGRAM_N);
851 let partition_id = (token as usize).saturating_sub(MIN_TOKEN) / divisor;
852 partitions[partition_id % num_workers].push((token, *row_id));
853 });
854 } else {
855 partitions[0].push((0, *row_id));
856 }
857 }
858 Ok(partitions)
859 }
860
/// Consumes `data`, builds postings on parallel workers, and flushes every worker.
///
/// Returns the worker numbers (spill file ids) of the workers that wrote a spill file.
///
/// # Errors
///
/// Fails if the stream's schema is rejected by `validate_schema`, if reading or
/// tokenizing a batch fails, or if a worker fails.
pub async fn train(&mut self, data: SendableRecordBatchStream) -> Result<Vec<usize>> {
    let schema = data.schema();
    Self::validate_schema(schema.as_ref())?;

    // One spawned task per partition, each fed through its own bounded channel.
    let num_workers = *DEFAULT_NUM_PARTITIONS;
    let mut senders = Vec::with_capacity(num_workers);
    let mut builders = Vec::with_capacity(num_workers);
    for worker_idx in 0..num_workers {
        let (send, mut recv) = tokio::sync::mpsc::channel(2);
        senders.push(send);

        // Worker numbers start at 1; 0 is used by this (parent) builder.
        let mut builder = self.clone_worker(worker_idx + 1);
        let future = tokio::spawn(async move {
            while let Some(partition) = recv.recv().await {
                builder.process_batch(partition).await?;
            }
            Result::Ok(builder)
        });
        builders.push(future);
    }

    // Tokenize batches on spawned tasks, up to DEFAULT_TOKENIZE_PARALLELISM at a time.
    let mut partitions_stream = data
        .and_then(|batch| {
            let tokenizer = self.tokenizer.clone();
            std::future::ready(Ok(tokio::task::spawn(async move {
                Ok(Self::tokenize_and_partition(
                    &tokenizer,
                    batch,
                    num_workers,
                )?)
            })
            .map(|res| res.unwrap())))
        })
        .try_buffer_unordered(*DEFAULT_TOKENIZE_PARALLELISM);

    // Partition i of every tokenized batch goes to worker i.
    while let Some(partitions) = partitions_stream.try_next().await? {
        for (part_idx, partition) in partitions.into_iter().enumerate() {
            senders[part_idx].send(partition).await.unwrap();
        }
    }

    // Dropping the senders closes the channels, which ends the workers' receive loops.
    std::mem::drop(senders);
    let builders = futures::future::try_join_all(builders).await?;

    let mut to_spill = Vec::with_capacity(builders.len());

    // Flush whatever each worker still holds in memory.
    for builder in builders {
        let mut builder = builder?;
        let state = builder.state.take();
        if builder.flush(state).await? {
            to_spill.push(builder.worker_number);
        }
    }

    Ok(to_spill)
}
920
/// Writes `state` as a single record batch to `writer` and finishes the file.
async fn write(
    &mut self,
    mut writer: Box<dyn IndexWriter>,
    state: NGramIndexSpillState,
) -> Result<()> {
    writer.write_record_batch(state.try_into_batch()?).await?;
    writer.finish().await?;

    Ok(())
}
931
932 async fn stream_spill_reader(
933 reader: Arc<dyn IndexReader>,
934 ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
935 let num_rows = reader.num_rows();
936
937 Ok(stream::try_unfold(0, move |offset| {
938 let reader = reader.clone();
939 async move {
940 let batch_size = std::cmp::min(num_rows - offset, 64);
943 if batch_size == 0 {
944 return Ok(None);
945 }
946 let batch = reader.read_range(offset..offset + batch_size, None).await?;
947 let state = NGramIndexSpillState::try_from_batch(batch)?;
948 let new_offset = offset + batch_size;
949 Ok(Some((state, new_offset)))
950 }
951 .boxed()
952 }))
953 }
954
955 async fn stream_spill(
956 spill_store: Arc<dyn IndexStore>,
957 id: usize,
958 ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
959 let reader = spill_store
960 .open_index_file(&Self::spill_filename(id))
961 .await?;
962 Self::stream_spill_reader(reader).await
963 }
964
965 fn merge_spill_states(
966 left_opt: &mut Option<NGramIndexSpillState>,
967 right_opt: &mut Option<NGramIndexSpillState>,
968 ) -> NGramIndexSpillState {
969 let left = left_opt.take().unwrap();
970 let right = right_opt.take().unwrap();
971
972 let item_capacity = left.tokens.len() + right.tokens.len();
973 let mut merged_tokens = UInt32Builder::with_capacity(item_capacity);
974 let mut merged_bitmaps = Vec::with_capacity(left.bitmaps.len() + right.bitmaps.len());
975
976 let mut left_tokens = left.tokens.values().iter().copied();
977 let mut left_bitmaps = left.bitmaps.into_iter();
978 let mut right_tokens = right.tokens.values().iter().copied();
979 let mut right_bitmaps = right.bitmaps.into_iter();
980
981 let mut left_token = left_tokens.next();
982 let mut left_bitmap = left_bitmaps.next();
983 let mut right_token = right_tokens.next();
984 let mut right_bitmap = right_bitmaps.next();
985
986 while left_token.is_some() && right_token.is_some() {
987 let left_token_val = left_token.unwrap();
988 let right_token_val = right_token.unwrap();
989 match left_token_val.cmp(&right_token_val) {
990 std::cmp::Ordering::Less => {
991 merged_tokens.append_value(left_token_val);
992 merged_bitmaps.push(left_bitmap.unwrap());
993 left_token = left_tokens.next();
994 left_bitmap = left_bitmaps.next();
995 }
996 std::cmp::Ordering::Greater => {
997 merged_tokens.append_value(right_token_val);
998 merged_bitmaps.push(right_bitmap.unwrap());
999 right_token = right_tokens.next();
1000 right_bitmap = right_bitmaps.next();
1001 }
1002 std::cmp::Ordering::Equal => {
1003 merged_tokens.append_value(left_token_val);
1004 merged_bitmaps.push(left_bitmap.unwrap() | &right_bitmap.unwrap());
1005 left_token = left_tokens.next();
1006 left_bitmap = left_bitmaps.next();
1007 right_token = right_tokens.next();
1008 right_bitmap = right_bitmaps.next();
1009 }
1010 }
1011 }
1012
1013 let collect_remaining = |cur_token, tokens, cur_bitmap, bitmaps| {
1014 let tokens = UInt32Array::from_iter_values(once(cur_token).chain(tokens));
1015 let bitmaps = once(cur_bitmap).chain(bitmaps).collect::<Vec<_>>();
1016 NGramIndexSpillState { tokens, bitmaps }
1017 };
1018
1019 if let Some(left_token) = left_token {
1020 *left_opt = Some(collect_remaining(
1021 left_token,
1022 left_tokens,
1023 left_bitmap.unwrap(),
1024 left_bitmaps,
1025 ));
1026 } else {
1027 *left_opt = None;
1028 }
1029 if let Some(right_token) = right_token {
1030 *right_opt = Some(collect_remaining(
1031 right_token,
1032 right_tokens,
1033 right_bitmap.unwrap(),
1034 right_bitmaps,
1035 ));
1036 } else {
1037 *right_opt = None;
1038 }
1039
1040 NGramIndexSpillState {
1041 tokens: merged_tokens.finish(),
1042 bitmaps: merged_bitmaps,
1043 }
1044 }
1045
1046 async fn merge_spill_streams(
1047 mut left_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1048 mut right_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1049 writer: &mut dyn IndexWriter,
1050 ) -> Result<()> {
1051 let mut left_state = left_stream.try_next().await?;
1052 let mut right_state = right_stream.try_next().await?;
1053
1054 while left_state.is_some() || right_state.is_some() {
1055 if left_state.is_none() {
1056 let state = right_state.take().expect_ok()?;
1058 writer.write_record_batch(state.try_into_batch()?).await?;
1059 while let Some(state) = right_stream.try_next().await? {
1060 writer.write_record_batch(state.try_into_batch()?).await?;
1061 }
1062 } else if right_state.is_none() {
1063 let state = left_state.take().expect_ok()?;
1065 writer.write_record_batch(state.try_into_batch()?).await?;
1066 while let Some(state) = left_stream.try_next().await? {
1067 writer.write_record_batch(state.try_into_batch()?).await?;
1068 }
1069 } else {
1070 let merged = Self::merge_spill_states(&mut left_state, &mut right_state);
1072 writer.write_record_batch(merged.try_into_batch()?).await?;
1073 if left_state.is_none() {
1074 left_state = left_stream.try_next().await?;
1075 }
1076 if right_state.is_none() {
1077 right_state = right_stream.try_next().await?;
1078 }
1079 }
1080 }
1081
1082 writer.finish().await
1083 }
1084
1085 async fn merge_spill_files(
1086 spill_store: Arc<dyn IndexStore>,
1087 index_of_left: usize,
1088 index_of_right: usize,
1089 output_index: usize,
1090 ) -> Result<()> {
1091 info!(
1093 "Merge spill files {} and {} into {}",
1094 index_of_left, index_of_right, output_index
1095 );
1096
1097 let mut writer = spill_store
1098 .new_index_file(&Self::spill_filename(output_index), POSTINGS_SCHEMA.clone())
1099 .await?;
1100
1101 let (left_stream, right_stream) = futures::try_join!(
1102 Self::stream_spill(spill_store.clone(), index_of_left),
1103 Self::stream_spill(spill_store.clone(), index_of_right)
1104 )?;
1105
1106 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1107
1108 spill_store
1109 .delete_index_file(&Self::spill_filename(index_of_left))
1110 .await?;
1111 spill_store
1112 .delete_index_file(&Self::spill_filename(index_of_right))
1113 .await?;
1114
1115 Ok(())
1116 }
1117
1118 async fn merge_spills(&mut self, mut spill_files: Vec<usize>) -> Result<usize> {
1125 info!(
1126 "Merging {} index files into one combined index",
1127 spill_files.len()
1128 );
1129
1130 let mut spill_counter = spill_files.iter().max().expect_ok()? + 1;
1131 while spill_files.len() > 1 {
1132 let mut new_spills = Vec::with_capacity(spill_files.len() / 2);
1133 while spill_files.len() >= 2 {
1134 let left = spill_files.pop().expect_ok()?;
1135 let right = spill_files.pop().expect_ok()?;
1136 new_spills.push(tokio::spawn(Self::merge_spill_files(
1137 self.spill_store.clone(),
1138 left,
1139 right,
1140 spill_counter + new_spills.len(),
1141 )));
1142 }
1143 for i in 0..new_spills.len() {
1144 spill_files.push(spill_counter + i);
1145 }
1146 spill_counter += new_spills.len();
1147 futures::future::try_join_all(new_spills).await?;
1148 }
1149
1150 spill_files.pop().expect_ok()
1151 }
1152
1153 async fn merge_old_index(
1154 &mut self,
1155 new_data_num: usize,
1156 old_index: Arc<dyn IndexStore>,
1157 ) -> Result<usize> {
1158 info!("Merging old index into new index");
1159 let final_num = new_data_num + 1;
1160
1161 let mut writer = self
1162 .spill_store
1163 .new_index_file(&Self::spill_filename(final_num), POSTINGS_SCHEMA.clone())
1164 .await?;
1165
1166 let left_stream = Self::stream_spill(self.spill_store.clone(), new_data_num).await?;
1167 let old_reader = old_index.open_index_file(POSTINGS_FILENAME).await?;
1168 let right_stream = Self::stream_spill_reader(old_reader).await?;
1169
1170 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1171
1172 self.spill_store
1173 .delete_index_file(&Self::spill_filename(new_data_num))
1174 .await?;
1175
1176 Ok(final_num)
1177 }
1178
1179 pub async fn write_index(
1180 mut self,
1181 store: &dyn IndexStore,
1182 spill_files: Vec<usize>,
1183 old_index: Option<Arc<dyn IndexStore>>,
1184 ) -> Result<()> {
1185 let mut writer = store
1186 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1187 .await?;
1188
1189 if spill_files.is_empty() {
1190 if let Some(old_index) = old_index {
1191 old_index.copy_index_file(POSTINGS_FILENAME, store).await?;
1193 } else {
1194 let mut writer = store
1196 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1197 .await?;
1198 writer.finish().await?;
1199 }
1200 return Ok(());
1201 }
1202
1203 let mut index_to_copy = self.merge_spills(spill_files).await?;
1204
1205 if let Some(old_index) = old_index {
1206 index_to_copy = self.merge_old_index(index_to_copy, old_index).await?;
1207 }
1208
1209 let reader = self
1210 .spill_store
1211 .open_index_file(&Self::spill_filename(index_to_copy))
1212 .await?;
1213
1214 let num_rows = reader.num_rows();
1215 let mut offset = 0;
1216
1217 while offset < num_rows {
1218 let batch_size = std::cmp::min(num_rows - offset, 64);
1219 let batch = reader.read_range(offset..offset + batch_size, None).await?;
1220 writer.write_record_batch(batch).await?;
1221 offset += batch_size;
1222 }
1223
1224 writer.finish().await
1225 }
1226}
1227
1228#[derive(Debug, Default)]
1229pub struct NGramIndexPlugin;
1230
1231impl NGramIndexPlugin {
1232 pub async fn train_ngram_index(
1233 batches_source: SendableRecordBatchStream,
1234 index_store: &dyn IndexStore,
1235 ) -> Result<()> {
1236 let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
1237
1238 let spill_files = builder.train(batches_source).await?;
1239
1240 builder.write_index(index_store, spill_files, None).await
1241 }
1242}
1243
1244#[async_trait]
1245impl ScalarIndexPlugin for NGramIndexPlugin {
1246 fn name(&self) -> &str {
1247 "NGram"
1248 }
1249
1250 fn new_training_request(
1251 &self,
1252 _params: &str,
1253 field: &Field,
1254 ) -> Result<Box<dyn TrainingRequest>> {
1255 if !matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
1256 return Err(Error::invalid_input_source(format!(
1257 "A ngram index can only be created on a Utf8 or LargeUtf8 field. Column has type {:?}",
1258 field.data_type()
1259 )
1260 .into()));
1261 }
1262 Ok(Box::new(DefaultTrainingRequest::new(
1263 TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
1264 )))
1265 }
1266
1267 fn provides_exact_answer(&self) -> bool {
1268 false
1269 }
1270
1271 fn version(&self) -> u32 {
1272 NGRAM_INDEX_VERSION
1273 }
1274
1275 fn new_query_parser(
1276 &self,
1277 index_name: String,
1278 _index_details: &prost_types::Any,
1279 ) -> Option<Box<dyn ScalarQueryParser>> {
1280 Some(Box::new(TextQueryParser::new(index_name, true)))
1281 }
1282
1283 async fn train_index(
1284 &self,
1285 data: SendableRecordBatchStream,
1286 index_store: &dyn IndexStore,
1287 _request: Box<dyn TrainingRequest>,
1288 fragment_ids: Option<Vec<u32>>,
1289 _progress: Arc<dyn crate::progress::IndexBuildProgress>,
1290 ) -> Result<CreatedIndex> {
1291 if fragment_ids.is_some() {
1292 return Err(Error::invalid_input_source(
1293 "NGram index does not support fragment training".into(),
1294 ));
1295 }
1296
1297 Self::train_ngram_index(data, index_store).await?;
1298 Ok(CreatedIndex {
1299 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
1300 .unwrap(),
1301 index_version: NGRAM_INDEX_VERSION,
1302 files: Some(index_store.list_files_with_sizes().await?),
1303 })
1304 }
1305
1306 async fn load_index(
1307 &self,
1308 index_store: Arc<dyn IndexStore>,
1309 _index_details: &prost_types::Any,
1310 frag_reuse_index: Option<Arc<FragReuseIndex>>,
1311 cache: &LanceCache,
1312 ) -> Result<Arc<dyn ScalarIndex>> {
1313 Ok(NGramIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
1314 }
1315}
1316
1317#[cfg(test)]
1318mod tests {
1319 use std::{
1320 collections::{HashMap, HashSet},
1321 sync::Arc,
1322 };
1323
1324 use arrow::datatypes::UInt64Type;
1325 use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
1326 use arrow_schema::{DataType, Field, Schema};
1327 use datafusion::{
1328 execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
1329 };
1330 use datafusion_common::DataFusionError;
1331 use futures::{TryStreamExt, stream};
1332 use itertools::Itertools;
1333 use lance_core::{
1334 ROW_ID,
1335 cache::LanceCache,
1336 utils::{mask::RowAddrTreeMap, tempfile::TempDir},
1337 };
1338 use lance_datagen::{BatchCount, ByteCount, RowCount};
1339 use lance_io::object_store::ObjectStore;
1340 use lance_tokenizer::TextAnalyzer;
1341
1342 use crate::scalar::{
1343 ScalarIndex, SearchResult, TextQuery,
1344 lance_format::LanceIndexStore,
1345 ngram::{NGramIndex, NGramIndexBuilder, NGramIndexBuilderOptions},
1346 };
1347 use crate::{metrics::NoOpMetricsCollector, scalar::registry::VALUE_COLUMN_NAME};
1348
1349 use super::{NGRAM_TOKENIZER, ngram_to_token, tokenize_visitor};
1350
1351 fn collect_tokens(analyzer: &TextAnalyzer, text: &str) -> Vec<String> {
1352 let mut tokens = Vec::with_capacity(text.len() * 3);
1353 tokenize_visitor(analyzer, text, |token| tokens.push(token.to_owned()));
1354 tokens
1355 }
1356
1357 #[test]
1358 fn test_tokenizer() {
1359 let tokenizer = NGRAM_TOKENIZER.clone();
1360
1361 let tokens = collect_tokens(&tokenizer, "café");
1363 assert_eq!(
1364 tokens,
1365 vec!["caf", "afe"] );
1367
1368 let tokens = collect_tokens(&tokenizer, "a1b2");
1370 assert_eq!(tokens, vec!["a1b", "1b2"]);
1371
1372 let tokens = collect_tokens(&tokenizer, "abc👍b!c24");
1374
1375 assert_eq!(tokens, vec!["abc", "c24"]);
1376
1377 let tokens = collect_tokens(&tokenizer, "anstoß");
1378
1379 assert_eq!(tokens, vec!["ans", "nst", "sto", "tos", "oss"]);
1380
1381 let tokens = collect_tokens(&tokenizer, "ABC");
1383 assert_eq!(tokens, vec!["abc"]);
1384
1385 let tokens = collect_tokens(&tokenizer, "ababab");
1387 assert_eq!(
1390 tokens,
1391 vec!["aba", "bab", "aba", "bab"] );
1393 }
1394
1395 async fn do_train(
1396 mut builder: NGramIndexBuilder,
1397 data: SendableRecordBatchStream,
1398 ) -> (NGramIndex, Arc<TempDir>) {
1399 let spill_files = builder.train(data).await.unwrap();
1400
1401 let tmpdir = Arc::new(TempDir::default());
1402 let test_store = LanceIndexStore::new(
1403 Arc::new(ObjectStore::local()),
1404 tmpdir.obj_path(),
1405 Arc::new(LanceCache::no_cache()),
1406 );
1407
1408 builder
1409 .write_index(&test_store, spill_files, None)
1410 .await
1411 .unwrap();
1412
1413 (
1414 NGramIndex::from_store(Arc::new(test_store), None, &LanceCache::no_cache())
1415 .await
1416 .unwrap(),
1417 tmpdir,
1418 )
1419 }
1420
1421 async fn get_posting_list_for_trigram(index: &NGramIndex, trigram: &str) -> Vec<u64> {
1422 let token = ngram_to_token(trigram, 3);
1423 let row_offset = index.tokens[&token];
1424 let list = index
1425 .list_reader
1426 .ngram_list(row_offset, &NoOpMetricsCollector)
1427 .await
1428 .unwrap();
1429 list.bitmap.iter().sorted().collect()
1430 }
1431
1432 async fn get_null_posting_list(index: &NGramIndex) -> Vec<u64> {
1433 let row_offset = index.tokens[&0];
1434 let list = index
1435 .list_reader
1436 .ngram_list(row_offset, &NoOpMetricsCollector)
1437 .await
1438 .unwrap();
1439 list.bitmap.iter().sorted().collect()
1440 }
1441
1442 #[test_log::test(tokio::test)]
1443 async fn test_basic_ngram_index() {
1444 let data = StringArray::from_iter_values([
1445 "cat",
1446 "dog",
1447 "cat dog",
1448 "dog cat",
1449 "elephant",
1450 "mouse",
1451 "rhino",
1452 "giraffe",
1453 "rhinos nose",
1454 ]);
1455 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1456 let schema = Arc::new(Schema::new(vec![
1457 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, false),
1458 Field::new(ROW_ID, DataType::UInt64, false),
1459 ]));
1460 let data =
1461 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1462 let data = Box::pin(RecordBatchStreamAdapter::new(
1463 schema,
1464 stream::once(std::future::ready(Ok(data))),
1465 ));
1466
1467 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1468
1469 let (index, _tmpdir) = do_train(builder, data).await;
1470 assert_eq!(index.tokens.len(), 21);
1471
1472 let res = index
1474 .search(
1475 &TextQuery::StringContains("cat".to_string()),
1476 &NoOpMetricsCollector,
1477 )
1478 .await
1479 .unwrap();
1480
1481 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 2, 3]));
1482
1483 assert_eq!(expected, res);
1484
1485 let res = index
1487 .search(
1488 &TextQuery::StringContains("nos nos".to_string()),
1489 &NoOpMetricsCollector,
1490 )
1491 .await
1492 .unwrap();
1493 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1494 assert_eq!(expected, res);
1495
1496 let res = index
1498 .search(
1499 &TextQuery::StringContains("tdo".to_string()),
1500 &NoOpMetricsCollector,
1501 )
1502 .await
1503 .unwrap();
1504 let expected = SearchResult::exact(RowAddrTreeMap::new());
1505 assert_eq!(expected, res);
1506
1507 let res = index
1509 .search(
1510 &TextQuery::StringContains("inose".to_string()),
1511 &NoOpMetricsCollector,
1512 )
1513 .await
1514 .unwrap();
1515 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1516 assert_eq!(expected, res);
1517
1518 let res = index
1520 .search(
1521 &TextQuery::StringContains("ab".to_string()),
1522 &NoOpMetricsCollector,
1523 )
1524 .await
1525 .unwrap();
1526 let expected = SearchResult::at_least(RowAddrTreeMap::new());
1527 assert_eq!(expected, res);
1528
1529 let res = index
1531 .search(
1532 &TextQuery::StringContains("no nos".to_string()),
1533 &NoOpMetricsCollector,
1534 )
1535 .await
1536 .unwrap();
1537 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1538 assert_eq!(expected, res);
1539 }
1540
1541 fn test_data_schema() -> Arc<Schema> {
1542 Arc::new(Schema::new(vec![
1543 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1544 Field::new(ROW_ID, DataType::UInt64, false),
1545 ]))
1546 }
1547
1548 fn simple_data_with_nulls() -> SendableRecordBatchStream {
1549 let data = StringArray::from_iter(&[Some("cat"), Some("dog"), None, None, Some("cat dog")]);
1550 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1551 let schema = test_data_schema();
1552 let data =
1553 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1554 Box::pin(RecordBatchStreamAdapter::new(
1555 schema,
1556 stream::once(std::future::ready(Ok(data))),
1557 ))
1558 }
1559
1560 #[test_log::test(tokio::test)]
1561 async fn test_ngram_nulls() {
1562 let data = simple_data_with_nulls();
1563
1564 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1565
1566 let (index, _tmpdir) = do_train(builder, data).await;
1567 assert_eq!(index.tokens.len(), 3);
1568
1569 let res = index
1570 .search(
1571 &TextQuery::StringContains("cat".to_string()),
1572 &NoOpMetricsCollector,
1573 )
1574 .await
1575 .unwrap();
1576 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 4]));
1577 assert_eq!(expected, res);
1578
1579 let null_posting_list = get_null_posting_list(&index).await;
1580 assert_eq!(null_posting_list, vec![2, 3]);
1581
1582 }
1584
1585 fn empty_data() -> SendableRecordBatchStream {
1586 Box::pin(RecordBatchStreamAdapter::new(
1587 test_data_schema(),
1588 stream::empty::<lance_core::error::DataFusionResult<RecordBatch>>(),
1589 ))
1590 }
1591
1592 #[test_log::test(tokio::test)]
1593 async fn test_train_empty() {
1594 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1595
1596 let (index, _tmpdir) = do_train(builder, empty_data()).await;
1597 assert_eq!(index.tokens.len(), 0);
1598 }
1599
1600 #[test_log::test(tokio::test)]
1601 async fn test_update_empty() {
1602 let data = simple_data_with_nulls();
1603
1604 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1605 let (index, _tmpdir) = do_train(builder, empty_data()).await;
1606
1607 let new_tmpdir = Arc::new(TempDir::default());
1608 let test_store = Arc::new(LanceIndexStore::new(
1609 Arc::new(ObjectStore::local()),
1610 new_tmpdir.obj_path(),
1611 Arc::new(LanceCache::no_cache()),
1612 ));
1613
1614 index.update(data, test_store.as_ref(), None).await.unwrap();
1615
1616 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1617 .await
1618 .unwrap();
1619 assert_eq!(index.tokens.len(), 3);
1620 }
1621
1622 async fn row_ids_in_index(index: &NGramIndex) -> Vec<u64> {
1623 let mut row_ids = HashSet::new();
1624 for row_offset in index.tokens.values() {
1625 let list = index
1626 .list_reader
1627 .ngram_list(*row_offset, &NoOpMetricsCollector)
1628 .await
1629 .unwrap();
1630 row_ids.extend(list.bitmap.iter());
1631 }
1632 row_ids.into_iter().sorted().collect()
1633 }
1634
1635 #[test_log::test(tokio::test)]
1636 async fn test_ngram_index_remap() {
1637 let data = simple_data_with_nulls();
1638 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1639 let (index, _tmpdir) = do_train(builder, data).await;
1640
1641 let row_ids = row_ids_in_index(&index).await;
1642 assert_eq!(row_ids, vec![0, 1, 2, 3, 4]);
1643
1644 let new_tmpdir = Arc::new(TempDir::default());
1645 let test_store = Arc::new(LanceIndexStore::new(
1646 Arc::new(ObjectStore::local()),
1647 new_tmpdir.obj_path(),
1648 Arc::new(LanceCache::no_cache()),
1649 ));
1650
1651 let remapping = HashMap::from([(2, Some(100)), (3, None), (4, Some(101))]);
1652 index.remap(&remapping, test_store.as_ref()).await.unwrap();
1653
1654 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1655 .await
1656 .unwrap();
1657 let row_ids = row_ids_in_index(&index).await;
1658 assert_eq!(row_ids, vec![0, 1, 100, 101]);
1659
1660 let null_posting_list = get_null_posting_list(&index).await;
1661 assert_eq!(null_posting_list, vec![100]);
1662 }
1663
1664 #[test_log::test(tokio::test)]
1665 async fn test_ngram_index_merge() {
1666 let data = simple_data_with_nulls();
1667 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1668 let (index, _tmpdir) = do_train(builder, data).await;
1669
1670 let data = StringArray::from_iter(&[Some("giraffe"), Some("cat"), None]);
1671 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64 + 100));
1672 let schema = Arc::new(Schema::new(vec![
1673 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1674 Field::new(ROW_ID, DataType::UInt64, false),
1675 ]));
1676 let data =
1677 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1678 let data = Box::pin(RecordBatchStreamAdapter::new(
1679 schema,
1680 stream::once(std::future::ready(Ok(data))),
1681 ));
1682
1683 let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1684 assert_eq!(posting_list, vec![0, 4]);
1685
1686 let new_tmpdir = Arc::new(TempDir::default());
1687 let test_store = Arc::new(LanceIndexStore::new(
1688 Arc::new(ObjectStore::local()),
1689 new_tmpdir.obj_path(),
1690 Arc::new(LanceCache::no_cache()),
1691 ));
1692
1693 index.update(data, test_store.as_ref(), None).await.unwrap();
1694
1695 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1696 .await
1697 .unwrap();
1698 let row_ids = row_ids_in_index(&index).await;
1699 assert_eq!(row_ids, vec![0, 1, 2, 3, 4, 100, 101, 102]);
1700
1701 let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1702 assert_eq!(posting_list, vec![0, 4, 101]);
1703
1704 let posting_list = get_posting_list_for_trigram(&index, "ffe").await;
1705 assert_eq!(posting_list, vec![100]);
1706
1707 let posting_list = get_null_posting_list(&index).await;
1708 assert_eq!(posting_list, vec![2, 3, 102]);
1709 }
1710
1711 #[test_log::test(tokio::test)]
1712 async fn test_ngram_index_with_spill() {
1713 let (data, schema) = lance_datagen::gen_batch()
1714 .col(
1715 VALUE_COLUMN_NAME,
1716 lance_datagen::array::rand_utf8(ByteCount::from(50), false),
1717 )
1718 .col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
1719 .into_reader_stream(RowCount::from(128), BatchCount::from(32));
1720
1721 let data = Box::pin(RecordBatchStreamAdapter::new(
1722 schema,
1723 data.map_err(|arrow_err| DataFusionError::ArrowError(Box::new(arrow_err), None)),
1724 ));
1725
1726 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions {
1727 tokens_per_spill: 100,
1728 })
1729 .unwrap();
1730
1731 let (index, _tmpdir) = do_train(builder, data).await;
1732
1733 assert_eq!(index.tokens.len(), 29012);
1734 }
1735}