1use std::any::Any;
5use std::collections::BTreeMap;
6use std::iter::once;
7use std::time::Instant;
8use std::{collections::HashMap, sync::Arc};
9
10use super::lance_format::LanceIndexStore;
11use super::{
12 AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector,
13 ScalarIndex, ScalarIndexParams, SearchResult, TextQuery,
14};
15use crate::frag_reuse::FragReuseIndex;
16use crate::metrics::NoOpMetricsCollector;
17use crate::pbold;
18use crate::scalar::expression::{ScalarQueryParser, TextQueryParser};
19use crate::scalar::registry::{
20 DefaultTrainingRequest, ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest,
21 VALUE_COLUMN_NAME,
22};
23use crate::scalar::{CreatedIndex, UpdateCriteria};
24use crate::vector::VectorIndex;
25use crate::{Index, IndexType};
26use arrow::array::{AsArray, UInt32Builder};
27use arrow::datatypes::{UInt32Type, UInt64Type};
28use arrow_array::{BinaryArray, RecordBatch, UInt32Array};
29use arrow_schema::{DataType, Field, Schema, SchemaRef};
30use async_trait::async_trait;
31use datafusion::execution::SendableRecordBatchStream;
32use deepsize::DeepSizeOf;
33use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream};
34use lance_arrow::iter_str_array;
35use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
36use lance_core::error::LanceOptionExt;
37use lance_core::utils::address::RowAddress;
38use lance_core::utils::tempfile::TempDir;
39use lance_core::utils::tokio::get_num_compute_intensive_cpus;
40use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
41use lance_core::{Error, utils::mask::RowAddrTreeMap};
42use lance_core::{ROW_ID, Result};
43use lance_io::object_store::ObjectStore;
44use lance_tokenizer::{
45 AlphaNumOnlyFilter, AsciiFoldingFilter, LowerCaser, NgramTokenizer, RawTokenizer, TextAnalyzer,
46};
47use log::info;
48use roaring::{RoaringBitmap, RoaringTreemap};
49use serde::Serialize;
50use tracing::instrument;
51
/// Name of the column that stores the `u32`-encoded ngram tokens.
const TOKENS_COL: &str = "tokens";
/// Name of the column that stores the serialized posting-list bitmaps.
const POSTING_LIST_COL: &str = "posting_list";
/// Name of the index file holding the token / posting-list rows.
const POSTINGS_FILENAME: &str = "ngram_postings.lance";
/// Version number reported as `index_version` when an index is created.
const NGRAM_INDEX_VERSION: u32 = 0;
56
57use std::sync::LazyLock;
58
59pub static TOKENS_FIELD: LazyLock<Field> =
60 LazyLock::new(|| Field::new(TOKENS_COL, DataType::UInt32, true));
61pub static POSTINGS_FIELD: LazyLock<Field> =
62 LazyLock::new(|| Field::new(POSTING_LIST_COL, DataType::Binary, false));
63pub static POSTINGS_SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
64 Arc::new(Schema::new(vec![
65 TOKENS_FIELD.clone(),
66 POSTINGS_FIELD.clone(),
67 ]))
68});
69pub static TEXT_PREPPER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
70 TextAnalyzer::builder(RawTokenizer::default())
71 .filter(LowerCaser)
72 .filter(AsciiFoldingFilter)
73 .build()
74});
75pub static NGRAM_TOKENIZER: LazyLock<TextAnalyzer> = LazyLock::new(|| {
77 TextAnalyzer::builder(NgramTokenizer::all_ngrams(3, 3).unwrap())
78 .filter(AlphaNumOnlyFilter)
79 .build()
80});
81
82fn tokenize_visitor(tokenizer: &TextAnalyzer, text: &str, mut visitor: impl FnMut(&String)) {
84 let mut prepper = TEXT_PREPPER.clone();
92 let mut tokenizer = tokenizer.clone();
93 let mut raw_stream = prepper.token_stream(text);
94 while raw_stream.advance() {
95 let mut token_stream = tokenizer.token_stream(&raw_stream.token().text);
96 while token_stream.advance() {
97 visitor(&token_stream.token().text);
98 }
99 }
100}
101
/// Number of distinct per-character positions used by `ngram_to_token`:
/// 36 symbols (`0-9`, `a-z`) mapped to 1..=36, with 0 left unused.
const ALPHA_SPAN: usize = 37;
/// Upper bound used when splitting the token space into partitions.
// NOTE(review): `ngram_to_token` can return up to 36*37^2 + 36*37 + 36 for a
// trigram, which is larger than this value — confirm this bound is intended.
const MAX_TOKEN: usize = ALPHA_SPAN.pow(2) + ALPHA_SPAN;
/// Lower bound used when splitting the token space into partitions.
const MIN_TOKEN: usize = 0;
/// Length of the ngrams handled by this index.
const NGRAM_N: usize = 3;

/// Encodes an ngram made of the bytes `0-9` / `a-z` as a base-`ALPHA_SPAN`
/// number.
///
/// Each byte is mapped to `1..=36` (digits first, then letters) and weighted
/// by `ALPHA_SPAN^(ngram_length - idx - 1)`, where `idx` is the byte's
/// position. An empty ngram encodes to 0.
///
/// # Panics
/// Hits `unreachable!()` for any byte greater than `b'z'`.
fn ngram_to_token(ngram: &str, ngram_length: usize) -> u32 {
    ngram
        .bytes()
        .enumerate()
        .fold(0u32, |token, (idx, byte)| {
            let rank = match byte {
                0..=b'9' => byte - b'0',
                b':'..=b'z' => byte - b'a' + 10,
                _ => unreachable!(),
            };
            // Shift by one so that every symbol has a non-zero position.
            let pos = rank + 1;
            debug_assert!(pos < ALPHA_SPAN as u8);
            let exponent = ngram_length as u32 - idx as u32 - 1;
            let mult = ALPHA_SPAN.pow(exponent) as u32;
            token + pos as u32 * mult
        })
}
144
145#[derive(Serialize)]
147struct NGramStatistics {
148 num_ngrams: usize,
149}
150
151#[derive(Debug)]
153pub struct NGramPostingList {
154 bitmap: RoaringTreemap,
155}
156
157impl DeepSizeOf for NGramPostingList {
158 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
159 self.bitmap.serialized_size()
160 }
161}
162
163#[derive(Debug, Clone)]
165pub struct NGramPostingListKey {
166 pub row_offset: u32,
167}
168
169impl CacheKey for NGramPostingListKey {
170 type ValueType = NGramPostingList;
171
172 fn key(&self) -> std::borrow::Cow<'_, str> {
173 format!("posting-list-{}", self.row_offset).into()
174 }
175
176 fn type_name() -> &'static str {
177 "NGramPostingList"
178 }
179}
180
181impl NGramPostingList {
182 fn try_from_batch(
183 batch: RecordBatch,
184 frag_reuse_index: Option<Arc<FragReuseIndex>>,
185 ) -> Result<Self> {
186 let bitmap_bytes = batch.column(0).as_binary::<i32>().value(0);
187 let mut bitmap = RoaringTreemap::deserialize_from(bitmap_bytes)
188 .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))?;
189 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
190 bitmap = frag_reuse_index_ref.remap_row_ids_roaring_tree_map(&bitmap);
191 }
192 Ok(Self { bitmap })
193 }
194
195 fn intersect<'a>(lists: impl IntoIterator<Item = &'a Self>) -> RoaringTreemap {
196 let mut iter = lists.into_iter();
197 let mut result = iter
198 .next()
199 .map(|list| list.bitmap.clone())
200 .unwrap_or_default();
201 for list in iter {
202 result &= &list.bitmap;
203 }
204 result
205 }
206}
207
208struct NGramPostingListReader {
210 reader: Arc<dyn IndexReader>,
211 frag_reuse_index: Option<Arc<FragReuseIndex>>,
212 index_cache: WeakLanceCache,
213}
214
215impl DeepSizeOf for NGramPostingListReader {
216 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
217 0
218 }
219}
220
221impl std::fmt::Debug for NGramPostingListReader {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct("NGramListReader").finish()
224 }
225}
226
227impl NGramPostingListReader {
228 #[instrument(level = "debug", skip(self, metrics))]
229 pub async fn ngram_list(
230 &self,
231 row_offset: u32,
232 metrics: &dyn MetricsCollector,
233 ) -> Result<Arc<NGramPostingList>> {
234 self.index_cache.get_or_insert_with_key(NGramPostingListKey { row_offset }, || async move {
235 metrics.record_part_load();
236 tracing::info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="ngram", part_id=row_offset);
237 let batch = self
238 .reader
239 .read_range(
240 row_offset as usize..row_offset as usize + 1,
241 Some(&[POSTING_LIST_COL]),
242 )
243 .await?;
244 NGramPostingList::try_from_batch(batch, self.frag_reuse_index.clone())
245 }).await
246 }
247}
248
249pub struct NGramIndex {
264 tokens: HashMap<u32, u32>,
266 list_reader: Arc<NGramPostingListReader>,
268 tokenizer: TextAnalyzer,
273 io_parallelism: usize,
274 store: Arc<dyn IndexStore>,
276}
277
278impl std::fmt::Debug for NGramIndex {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 f.debug_struct("NGramIndex")
281 .field("tokens", &self.tokens)
282 .field("list_reader", &self.list_reader)
283 .finish()
284 }
285}
286
287impl DeepSizeOf for NGramIndex {
288 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
289 self.tokens.deep_size_of_children(context)
290 }
291}
292
293impl NGramIndex {
294 async fn from_store(
295 store: Arc<dyn IndexStore>,
296 frag_reuse_index: Option<Arc<FragReuseIndex>>,
297 index_cache: &LanceCache,
298 ) -> Result<Self> {
299 let tokens = store.open_index_file(POSTINGS_FILENAME).await?;
300 let tokens = tokens
301 .read_range(0..tokens.num_rows(), Some(&[TOKENS_COL]))
302 .await?;
303
304 let tokens_map = HashMap::from_iter(
305 tokens
306 .column(0)
307 .as_primitive::<UInt32Type>()
308 .values()
309 .iter()
310 .copied()
311 .enumerate()
312 .map(|(idx, token)| (token, idx as u32)),
313 );
314
315 let posting_reader = Arc::new(NGramPostingListReader {
316 reader: store.open_index_file(POSTINGS_FILENAME).await?,
317 frag_reuse_index,
318 index_cache: WeakLanceCache::from(index_cache),
319 });
320
321 Ok(Self {
322 io_parallelism: store.io_parallelism(),
323 tokens: tokens_map,
324 list_reader: posting_reader,
325 tokenizer: NGRAM_TOKENIZER.clone(),
326 store,
327 })
328 }
329
330 fn remap_batch(
331 &self,
332 batch: RecordBatch,
333 mapping: &HashMap<u64, Option<u64>>,
334 ) -> Result<RecordBatch> {
335 let posting_lists_array = batch
336 .column_by_name(POSTING_LIST_COL)
337 .expect_ok()?
338 .as_binary::<i32>();
339
340 let new_posting_lists = posting_lists_array
341 .iter()
342 .map(|posting_list| {
343 let posting_list = posting_list.unwrap();
344 let posting_list = RoaringTreemap::deserialize_from(posting_list)?;
345 let new_posting_list =
346 RoaringTreemap::from_iter(posting_list.into_iter().filter_map(|row_id| {
347 match mapping.get(&row_id) {
348 Some(Some(new_row_id)) => Some(*new_row_id),
349 Some(None) => None,
350 None => Some(row_id),
351 }
352 }));
353 let mut buf = Vec::with_capacity(new_posting_list.serialized_size());
354 new_posting_list.serialize_into(&mut buf)?;
355 Ok(buf)
356 })
357 .collect::<Result<Vec<_>>>()?;
358
359 let new_posting_lists_array = BinaryArray::from_iter_values(new_posting_lists);
360
361 Ok(RecordBatch::try_new(
362 POSTINGS_SCHEMA.clone(),
363 vec![
364 batch.column_by_name(TOKENS_COL).expect_ok()?.clone(),
365 Arc::new(new_posting_lists_array),
366 ],
367 )?)
368 }
369
370 async fn load(
371 store: Arc<dyn IndexStore>,
372 frag_reuse_index: Option<Arc<FragReuseIndex>>,
373 index_cache: &LanceCache,
374 ) -> Result<Arc<Self>>
375 where
376 Self: Sized,
377 {
378 Ok(Arc::new(
379 Self::from_store(store, frag_reuse_index, index_cache).await?,
380 ))
381 }
382}
383
384#[async_trait]
385impl Index for NGramIndex {
386 fn as_any(&self) -> &dyn Any {
387 self
388 }
389
390 fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
391 self
392 }
393
394 fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
395 Err(Error::invalid_input_source(
396 "NGramIndex is not a vector index".into(),
397 ))
398 }
399
400 fn statistics(&self) -> Result<serde_json::Value> {
401 let ngram_stats = NGramStatistics {
402 num_ngrams: self.tokens.len(),
403 };
404 serde_json::to_value(ngram_stats)
405 .map_err(|e| Error::internal(format!("Error serializing statistics: {}", e)))
406 }
407
408 async fn prewarm(&self) -> Result<()> {
409 Ok(())
411 }
412
413 fn index_type(&self) -> IndexType {
414 IndexType::NGram
415 }
416
417 async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
418 let mut frag_ids = RoaringBitmap::new();
419 for row_offset in self.tokens.values() {
420 let list = self
421 .list_reader
422 .ngram_list(*row_offset, &NoOpMetricsCollector)
423 .await?;
424 frag_ids.extend(
425 list.bitmap
426 .iter()
427 .map(|row_addr| RowAddress::from(row_addr).fragment_id()),
428 );
429 }
430 Ok(frag_ids)
431 }
432}
433
434#[async_trait]
435impl ScalarIndex for NGramIndex {
436 async fn search(
437 &self,
438 query: &dyn AnyQuery,
439 metrics: &dyn MetricsCollector,
440 ) -> Result<SearchResult> {
441 let query = query
442 .as_any()
443 .downcast_ref::<TextQuery>()
444 .ok_or_else(|| Error::invalid_input_source("Query is not a TextQuery".into()))?;
445 match query {
446 TextQuery::StringContains(substr) => {
447 if substr.len() < NGRAM_N {
448 return Ok(SearchResult::at_least(RowAddrTreeMap::new()));
450 }
451
452 let mut row_offsets = Vec::with_capacity(substr.len() * 3);
453 let mut missing = false;
454 tokenize_visitor(&self.tokenizer, substr, |ngram| {
455 let token = ngram_to_token(ngram, NGRAM_N);
456 if let Some(row_offset) = self.tokens.get(&token) {
457 row_offsets.push(*row_offset);
458 } else {
459 missing = true;
460 }
461 });
462 if missing {
464 return Ok(SearchResult::exact(RowAddrTreeMap::new()));
465 }
466 let posting_lists = futures::stream::iter(
467 row_offsets
468 .into_iter()
469 .map(|row_offset| self.list_reader.ngram_list(row_offset, metrics)),
470 )
471 .buffer_unordered(self.io_parallelism)
472 .try_collect::<Vec<_>>()
473 .await?;
474 metrics.record_comparisons(posting_lists.len());
475 let list_refs = posting_lists.iter().map(|list| list.as_ref());
476 let row_ids = NGramPostingList::intersect(list_refs);
477 Ok(SearchResult::at_most(RowAddrTreeMap::from(row_ids)))
478 }
479 }
480 }
481
482 fn can_remap(&self) -> bool {
483 true
484 }
485
486 async fn remap(
487 &self,
488 mapping: &HashMap<u64, Option<u64>>,
489 dest_store: &dyn IndexStore,
490 ) -> Result<CreatedIndex> {
491 let reader = self.store.open_index_file(POSTINGS_FILENAME).await?;
492 let mut writer = dest_store
493 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
494 .await?;
495
496 let mut offset = 0;
497 let num_rows = reader.num_rows();
498 const BATCH_SIZE: usize = 64;
499 while offset < num_rows {
500 let batch_size = BATCH_SIZE.min(num_rows - offset);
501 let batch = reader.read_range(offset..offset + batch_size, None).await?;
502 let batch = self.remap_batch(batch, mapping)?;
503 writer.write_record_batch(batch).await?;
504 offset += BATCH_SIZE;
505 }
506
507 writer.finish().await?;
508
509 Ok(CreatedIndex {
510 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
511 .unwrap(),
512 index_version: NGRAM_INDEX_VERSION,
513 files: Some(dest_store.list_files_with_sizes().await?),
514 })
515 }
516
517 async fn update(
518 &self,
519 new_data: SendableRecordBatchStream,
520 dest_store: &dyn IndexStore,
521 _old_data_filter: Option<super::OldIndexDataFilter>,
522 ) -> Result<CreatedIndex> {
523 let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
524 let spill_files = builder.train(new_data).await?;
525
526 builder
527 .write_index(dest_store, spill_files, Some(self.store.clone()))
528 .await?;
529
530 Ok(CreatedIndex {
531 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
532 .unwrap(),
533 index_version: NGRAM_INDEX_VERSION,
534 files: Some(dest_store.list_files_with_sizes().await?),
535 })
536 }
537
538 fn update_criteria(&self) -> UpdateCriteria {
539 UpdateCriteria::only_new_data(TrainingCriteria::new(TrainingOrdering::None).with_row_id())
540 }
541
542 fn derive_index_params(&self) -> Result<ScalarIndexParams> {
543 Ok(ScalarIndexParams::for_builtin(BuiltinIndexType::NGram))
544 }
545}
546
/// Tuning options for [`NGramIndexBuilder`].
#[derive(Debug, Clone)]
pub struct NGramIndexBuilderOptions {
    /// Number of (token, row id) pairs a worker accumulates before it
    /// flushes its in-memory state to a spill file.
    tokens_per_spill: usize,
}
551
/// Default spill threshold, read from `LANCE_NGRAM_TOKENS_PER_SPILL`
/// (1,000,000,000 when the variable is not set).
///
/// Panics on first use if the variable is set to a value that does not parse.
static DEFAULT_TOKENS_PER_SPILL: LazyLock<usize> = LazyLock::new(|| {
    let configured = std::env::var("LANCE_NGRAM_TOKENS_PER_SPILL")
        .unwrap_or_else(|_| "1000000000".to_string());
    configured
        .parse()
        .expect("failed to parse LANCE_NGRAM_TOKENS_PER_SPILL")
});
559static DEFAULT_NUM_PARTITIONS: LazyLock<usize> = LazyLock::new(|| {
565 std::env::var("LANCE_NGRAM_NUM_PARTITIONS")
566 .map(|s| s.parse().expect("failed to parse LANCE_NGRAM_PARALLELISM"))
567 .unwrap_or((get_num_compute_intensive_cpus() * 4).max(128))
568});
/// Default number of batches tokenized concurrently while training, read
/// from `LANCE_NGRAM_TOKENIZE_PARALLELISM` (8 when the variable is not set).
///
/// Panics on first use if the variable is set to a value that does not parse.
static DEFAULT_TOKENIZE_PARALLELISM: LazyLock<usize> = LazyLock::new(|| {
    match std::env::var("LANCE_NGRAM_TOKENIZE_PARALLELISM") {
        Ok(configured) => configured
            .parse()
            .expect("failed to parse LANCE_NGRAM_TOKENIZE_PARALLELISM"),
        Err(_) => 8,
    }
});
578
579impl Default for NGramIndexBuilderOptions {
580 fn default() -> Self {
581 Self {
582 tokens_per_spill: *DEFAULT_TOKENS_PER_SPILL,
583 }
584 }
585}
586
587struct NGramIndexSpillState {
591 tokens: UInt32Array,
592 bitmaps: Vec<RoaringTreemap>,
593}
594
595impl NGramIndexSpillState {
596 fn try_from_batch(batch: RecordBatch) -> Result<Self> {
597 let tokens = batch
598 .column_by_name(TOKENS_COL)
599 .expect_ok()?
600 .as_primitive::<UInt32Type>()
601 .clone();
602 let postings = batch
603 .column_by_name(POSTING_LIST_COL)
604 .expect_ok()?
605 .as_binary::<i32>();
606
607 let bitmaps = postings
608 .into_iter()
609 .map(|bytes| {
610 RoaringTreemap::deserialize_from(bytes.expect_ok()?)
611 .map_err(|e| Error::internal(format!("Error deserializing ngram list: {}", e)))
612 })
613 .collect::<Result<Vec<_>>>()?;
614
615 Ok(Self { tokens, bitmaps })
616 }
617
618 fn try_into_batch(self) -> Result<RecordBatch> {
619 let bitmap_array = BinaryArray::from_iter_values(self.bitmaps.into_iter().map(|bitmap| {
620 let mut buf = Vec::with_capacity(bitmap.serialized_size());
621 bitmap.serialize_into(&mut buf).unwrap();
622 buf
623 }));
624 Ok(RecordBatch::try_new(
625 POSTINGS_SCHEMA.clone(),
626 vec![Arc::new(self.tokens), Arc::new(bitmap_array)],
627 )?)
628 }
629}
630
631struct NGramIndexBuildState {
634 tokens_map: BTreeMap<u32, RoaringTreemap>,
635}
636
637impl NGramIndexBuildState {
638 fn starting() -> Self {
639 Self {
640 tokens_map: BTreeMap::new(),
641 }
642 }
643
644 fn take(&mut self) -> Self {
645 let mut taken = Self::starting();
646 std::mem::swap(&mut self.tokens_map, &mut taken.tokens_map);
647 taken
648 }
649
650 fn into_spill(self) -> NGramIndexSpillState {
651 let tokens = UInt32Array::from_iter_values(self.tokens_map.keys().copied());
653 let bitmaps = Vec::from_iter(self.tokens_map.into_values());
654
655 NGramIndexSpillState { bitmaps, tokens }
656 }
657}
658
659pub struct NGramIndexBuilder {
672 tokenizer: TextAnalyzer,
673 options: NGramIndexBuilderOptions,
674 tmpdir: Arc<TempDir>,
675 spill_store: Arc<dyn IndexStore>,
676
677 tokens_seen: usize,
678 worker_number: usize,
679 has_flushed: bool,
680
681 state: NGramIndexBuildState,
682}
683
684impl NGramIndexBuilder {
685 pub fn try_new(options: NGramIndexBuilderOptions) -> Result<Self> {
686 Self::from_state(NGramIndexBuildState::starting(), options)
687 }
688
689 fn clone_worker(&self, worker_number: usize) -> Self {
690 let mut bitmaps = Vec::with_capacity(36 * 36 * 36 + 1);
691 bitmaps.push(RoaringTreemap::new());
693 Self {
694 tokenizer: self.tokenizer.clone(),
695 state: NGramIndexBuildState::starting(),
696 tmpdir: self.tmpdir.clone(),
697 spill_store: self.spill_store.clone(),
698 options: self.options.clone(),
699 tokens_seen: 0,
700 worker_number,
701 has_flushed: false,
702 }
703 }
704
705 fn from_state(state: NGramIndexBuildState, options: NGramIndexBuilderOptions) -> Result<Self> {
706 let tokenizer = NGRAM_TOKENIZER.clone();
707
708 let tmpdir = Arc::new(TempDir::default());
709 let spill_store = Arc::new(LanceIndexStore::new(
710 Arc::new(ObjectStore::local()),
711 tmpdir.obj_path(),
712 Arc::new(LanceCache::no_cache()),
713 ));
714
715 Ok(Self {
716 tokenizer,
717 state,
718 tmpdir,
719 spill_store,
720 options,
721 tokens_seen: 0,
722 worker_number: 0,
723 has_flushed: false,
724 })
725 }
726
727 fn validate_schema(schema: &Schema) -> Result<()> {
728 if schema.fields().len() != 2 {
729 return Err(Error::invalid_input_source(
730 "Ngram index schema must have exactly two fields".into(),
731 ));
732 }
733 let values_field = schema.field_with_name(VALUE_COLUMN_NAME)?;
734 if *values_field.data_type() != DataType::Utf8
735 && *values_field.data_type() != DataType::LargeUtf8
736 {
737 return Err(Error::invalid_input_source(
738 "First field in ngram index schema must be of type Utf8/LargeUtf8".into(),
739 ));
740 }
741 let row_id_field = schema.field_with_name(ROW_ID)?;
742 if *row_id_field.data_type() != DataType::UInt64 {
743 return Err(Error::invalid_input_source(
744 "Second field in ngram index schema must be of type UInt64".into(),
745 ));
746 }
747 Ok(())
748 }
749
750 async fn process_batch(&mut self, tokens_and_ids: Vec<(u32, u64)>) -> Result<()> {
751 let mut tokens_seen = 0;
752 for (token, row_id) in tokens_and_ids {
753 tokens_seen += 1;
754 self.state
759 .tokens_map
760 .entry(token)
761 .or_default()
762 .insert(row_id);
763 }
764 self.tokens_seen += tokens_seen;
765 if self.tokens_seen >= self.options.tokens_per_spill {
766 let state = self.state.take();
767 self.flush(state).await?;
768 }
769 Ok(())
770 }
771
772 fn spill_filename(id: usize) -> String {
773 format!("spill-{}.lance", id)
774 }
775
776 fn tmp_spill_filename(id: usize) -> String {
777 format!("spill-{}.lance.tmp", id)
778 }
779
780 async fn flush(&mut self, state: NGramIndexBuildState) -> Result<bool> {
781 if self.tokens_seen == 0 {
782 assert!(state.tokens_map.is_empty());
783 return Ok(self.has_flushed);
784 }
785 self.tokens_seen = 0;
786 let spill_state = state.into_spill();
787 let flush_start = Instant::now();
788 debug_assert_ne!(self.worker_number, 0);
790 if self.has_flushed {
791 info!("Merging flush for worker {}", self.worker_number);
792 let mut writer = self
794 .spill_store
795 .new_index_file(
796 &Self::tmp_spill_filename(self.worker_number),
797 POSTINGS_SCHEMA.clone(),
798 )
799 .await?;
800
801 let left_stream = stream::once(std::future::ready(Ok(spill_state)));
802 let right_stream =
803 Self::stream_spill(self.spill_store.clone(), self.worker_number).await?;
804 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
805 drop(writer);
806 self.spill_store
807 .rename_index_file(
808 &Self::tmp_spill_filename(self.worker_number),
809 &Self::spill_filename(self.worker_number),
810 )
811 .await?;
812 } else {
813 info!("Initial flush for worker {}", self.worker_number);
815 self.has_flushed = true;
816 let writer = self
817 .spill_store
818 .new_index_file(
819 &Self::spill_filename(self.worker_number),
820 POSTINGS_SCHEMA.clone(),
821 )
822 .await?;
823 self.write(writer, spill_state).await?;
824 }
825 let flush_time = flush_start.elapsed();
826 info!(
827 "Flushed worker {} in {}ms",
828 self.worker_number,
829 flush_time.as_millis()
830 );
831 Ok(true)
832 }
833
834 fn tokenize_and_partition(
835 tokenizer: &TextAnalyzer,
836 batch: RecordBatch,
837 num_workers: usize,
838 ) -> Result<Vec<Vec<(u32, u64)>>> {
839 let text_iter = iter_str_array(batch.column_by_name(VALUE_COLUMN_NAME).expect_ok()?);
840 let row_id_col = batch
841 .column_by_name(ROW_ID)
842 .expect_ok()?
843 .as_primitive::<UInt64Type>();
844 let mut partitions = vec![Vec::with_capacity(batch.num_rows() * 1000); num_workers];
846 let divisor = (MAX_TOKEN - MIN_TOKEN) / num_workers;
847 for (text, row_id) in text_iter.zip(row_id_col.values()) {
848 if let Some(text) = text {
849 tokenize_visitor(tokenizer, text, |token| {
850 let token = ngram_to_token(token, NGRAM_N);
851 let partition_id = (token as usize).saturating_sub(MIN_TOKEN) / divisor;
852 partitions[partition_id % num_workers].push((token, *row_id));
853 });
854 } else {
855 partitions[0].push((0, *row_id));
856 }
857 }
858 Ok(partitions)
859 }
860
861 pub async fn train(&mut self, data: SendableRecordBatchStream) -> Result<Vec<usize>> {
862 let schema = data.schema();
863 Self::validate_schema(schema.as_ref())?;
864
865 let num_workers = *DEFAULT_NUM_PARTITIONS;
866 let mut senders = Vec::with_capacity(num_workers);
867 let mut builders = Vec::with_capacity(num_workers);
868 for worker_idx in 0..num_workers {
869 let (send, mut recv) = tokio::sync::mpsc::channel(2);
870 senders.push(send);
871
872 let mut builder = self.clone_worker(worker_idx + 1);
873 let future = tokio::spawn(async move {
874 while let Some(partition) = recv.recv().await {
875 builder.process_batch(partition).await?;
876 }
877 Result::Ok(builder)
878 });
879 builders.push(future);
880 }
881
882 let mut partitions_stream = data
883 .and_then(|batch| {
884 let tokenizer = self.tokenizer.clone();
885 std::future::ready(Ok(tokio::task::spawn(async move {
886 Ok(Self::tokenize_and_partition(
887 &tokenizer,
888 batch,
889 num_workers,
890 )?)
891 })
892 .map(|res| res.unwrap())))
893 })
894 .try_buffer_unordered(*DEFAULT_TOKENIZE_PARALLELISM);
895
896 while let Some(partitions) = partitions_stream.try_next().await? {
897 for (part_idx, partition) in partitions.into_iter().enumerate() {
898 senders[part_idx].send(partition).await.unwrap();
899 }
900 }
901
902 std::mem::drop(senders);
903 let builders = futures::future::try_join_all(builders).await?;
904
905 let mut to_spill = Vec::with_capacity(builders.len());
909
910 for builder in builders {
911 let mut builder = builder?;
912 let state = builder.state.take();
913 if builder.flush(state).await? {
914 to_spill.push(builder.worker_number);
915 }
916 }
917
918 Ok(to_spill)
919 }
920
921 async fn write(
922 &mut self,
923 mut writer: Box<dyn IndexWriter>,
924 state: NGramIndexSpillState,
925 ) -> Result<()> {
926 writer.write_record_batch(state.try_into_batch()?).await?;
927 writer.finish().await?;
928
929 Ok(())
930 }
931
932 async fn stream_spill_reader(
933 reader: Arc<dyn IndexReader>,
934 ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
935 let num_rows = reader.num_rows();
936
937 Ok(stream::try_unfold(0, move |offset| {
938 let reader = reader.clone();
939 async move {
940 let batch_size = std::cmp::min(num_rows - offset, 64);
943 if batch_size == 0 {
944 return Ok(None);
945 }
946 let batch = reader.read_range(offset..offset + batch_size, None).await?;
947 let state = NGramIndexSpillState::try_from_batch(batch)?;
948 let new_offset = offset + batch_size;
949 Ok(Some((state, new_offset)))
950 }
951 .boxed()
952 }))
953 }
954
955 async fn stream_spill(
956 spill_store: Arc<dyn IndexStore>,
957 id: usize,
958 ) -> Result<impl Stream<Item = Result<NGramIndexSpillState>>> {
959 let reader = spill_store
960 .open_index_file(&Self::spill_filename(id))
961 .await?;
962 Self::stream_spill_reader(reader).await
963 }
964
965 fn merge_spill_states(
966 left_opt: &mut Option<NGramIndexSpillState>,
967 right_opt: &mut Option<NGramIndexSpillState>,
968 ) -> NGramIndexSpillState {
969 let left = left_opt.take().unwrap();
970 let right = right_opt.take().unwrap();
971
972 let item_capacity = left.tokens.len() + right.tokens.len();
973 let mut merged_tokens = UInt32Builder::with_capacity(item_capacity);
974 let mut merged_bitmaps = Vec::with_capacity(left.bitmaps.len() + right.bitmaps.len());
975
976 let mut left_tokens = left.tokens.values().iter().copied();
977 let mut left_bitmaps = left.bitmaps.into_iter();
978 let mut right_tokens = right.tokens.values().iter().copied();
979 let mut right_bitmaps = right.bitmaps.into_iter();
980
981 let mut left_token = left_tokens.next();
982 let mut left_bitmap = left_bitmaps.next();
983 let mut right_token = right_tokens.next();
984 let mut right_bitmap = right_bitmaps.next();
985
986 while left_token.is_some() && right_token.is_some() {
987 let left_token_val = left_token.unwrap();
988 let right_token_val = right_token.unwrap();
989 match left_token_val.cmp(&right_token_val) {
990 std::cmp::Ordering::Less => {
991 merged_tokens.append_value(left_token_val);
992 merged_bitmaps.push(left_bitmap.unwrap());
993 left_token = left_tokens.next();
994 left_bitmap = left_bitmaps.next();
995 }
996 std::cmp::Ordering::Greater => {
997 merged_tokens.append_value(right_token_val);
998 merged_bitmaps.push(right_bitmap.unwrap());
999 right_token = right_tokens.next();
1000 right_bitmap = right_bitmaps.next();
1001 }
1002 std::cmp::Ordering::Equal => {
1003 merged_tokens.append_value(left_token_val);
1004 merged_bitmaps.push(left_bitmap.unwrap() | &right_bitmap.unwrap());
1005 left_token = left_tokens.next();
1006 left_bitmap = left_bitmaps.next();
1007 right_token = right_tokens.next();
1008 right_bitmap = right_bitmaps.next();
1009 }
1010 }
1011 }
1012
1013 let collect_remaining = |cur_token, tokens, cur_bitmap, bitmaps| {
1014 let tokens = UInt32Array::from_iter_values(once(cur_token).chain(tokens));
1015 let bitmaps = once(cur_bitmap).chain(bitmaps).collect::<Vec<_>>();
1016 NGramIndexSpillState { tokens, bitmaps }
1017 };
1018
1019 if let Some(left_token) = left_token {
1020 *left_opt = Some(collect_remaining(
1021 left_token,
1022 left_tokens,
1023 left_bitmap.unwrap(),
1024 left_bitmaps,
1025 ));
1026 } else {
1027 *left_opt = None;
1028 }
1029 if let Some(right_token) = right_token {
1030 *right_opt = Some(collect_remaining(
1031 right_token,
1032 right_tokens,
1033 right_bitmap.unwrap(),
1034 right_bitmaps,
1035 ));
1036 } else {
1037 *right_opt = None;
1038 }
1039
1040 NGramIndexSpillState {
1041 tokens: merged_tokens.finish(),
1042 bitmaps: merged_bitmaps,
1043 }
1044 }
1045
1046 async fn merge_spill_streams(
1047 mut left_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1048 mut right_stream: impl Stream<Item = Result<NGramIndexSpillState>> + Unpin,
1049 writer: &mut dyn IndexWriter,
1050 ) -> Result<()> {
1051 let mut left_state = left_stream.try_next().await?;
1052 let mut right_state = right_stream.try_next().await?;
1053
1054 while left_state.is_some() || right_state.is_some() {
1055 if left_state.is_none() {
1056 let state = right_state.take().expect_ok()?;
1058 writer.write_record_batch(state.try_into_batch()?).await?;
1059 while let Some(state) = right_stream.try_next().await? {
1060 writer.write_record_batch(state.try_into_batch()?).await?;
1061 }
1062 } else if right_state.is_none() {
1063 let state = left_state.take().expect_ok()?;
1065 writer.write_record_batch(state.try_into_batch()?).await?;
1066 while let Some(state) = left_stream.try_next().await? {
1067 writer.write_record_batch(state.try_into_batch()?).await?;
1068 }
1069 } else {
1070 let merged = Self::merge_spill_states(&mut left_state, &mut right_state);
1072 writer.write_record_batch(merged.try_into_batch()?).await?;
1073 if left_state.is_none() {
1074 left_state = left_stream.try_next().await?;
1075 }
1076 if right_state.is_none() {
1077 right_state = right_stream.try_next().await?;
1078 }
1079 }
1080 }
1081
1082 writer.finish().await
1083 }
1084
1085 async fn merge_spill_files(
1086 spill_store: Arc<dyn IndexStore>,
1087 index_of_left: usize,
1088 index_of_right: usize,
1089 output_index: usize,
1090 ) -> Result<()> {
1091 info!(
1093 "Merge spill files {} and {} into {}",
1094 index_of_left, index_of_right, output_index
1095 );
1096
1097 let mut writer = spill_store
1098 .new_index_file(&Self::spill_filename(output_index), POSTINGS_SCHEMA.clone())
1099 .await?;
1100
1101 let (left_stream, right_stream) = futures::try_join!(
1102 Self::stream_spill(spill_store.clone(), index_of_left),
1103 Self::stream_spill(spill_store.clone(), index_of_right)
1104 )?;
1105
1106 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1107
1108 spill_store
1109 .delete_index_file(&Self::spill_filename(index_of_left))
1110 .await?;
1111 spill_store
1112 .delete_index_file(&Self::spill_filename(index_of_right))
1113 .await?;
1114
1115 Ok(())
1116 }
1117
1118 async fn merge_spills(&mut self, mut spill_files: Vec<usize>) -> Result<usize> {
1125 info!(
1126 "Merging {} index files into one combined index",
1127 spill_files.len()
1128 );
1129
1130 let mut spill_counter = spill_files.iter().max().expect_ok()? + 1;
1131 while spill_files.len() > 1 {
1132 let mut new_spills = Vec::with_capacity(spill_files.len() / 2);
1133 while spill_files.len() >= 2 {
1134 let left = spill_files.pop().expect_ok()?;
1135 let right = spill_files.pop().expect_ok()?;
1136 new_spills.push(tokio::spawn(Self::merge_spill_files(
1137 self.spill_store.clone(),
1138 left,
1139 right,
1140 spill_counter + new_spills.len(),
1141 )));
1142 }
1143 for i in 0..new_spills.len() {
1144 spill_files.push(spill_counter + i);
1145 }
1146 spill_counter += new_spills.len();
1147 futures::future::try_join_all(new_spills).await?;
1148 }
1149
1150 spill_files.pop().expect_ok()
1151 }
1152
1153 async fn merge_old_index(
1154 &mut self,
1155 new_data_num: usize,
1156 old_index: Arc<dyn IndexStore>,
1157 ) -> Result<usize> {
1158 info!("Merging old index into new index");
1159 let final_num = new_data_num + 1;
1160
1161 let mut writer = self
1162 .spill_store
1163 .new_index_file(&Self::spill_filename(final_num), POSTINGS_SCHEMA.clone())
1164 .await?;
1165
1166 let left_stream = Self::stream_spill(self.spill_store.clone(), new_data_num).await?;
1167 let old_reader = old_index.open_index_file(POSTINGS_FILENAME).await?;
1168 let right_stream = Self::stream_spill_reader(old_reader).await?;
1169
1170 Self::merge_spill_streams(left_stream, right_stream, writer.as_mut()).await?;
1171
1172 self.spill_store
1173 .delete_index_file(&Self::spill_filename(new_data_num))
1174 .await?;
1175
1176 Ok(final_num)
1177 }
1178
1179 pub async fn write_index(
1180 mut self,
1181 store: &dyn IndexStore,
1182 spill_files: Vec<usize>,
1183 old_index: Option<Arc<dyn IndexStore>>,
1184 ) -> Result<()> {
1185 let mut writer = store
1186 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1187 .await?;
1188
1189 if spill_files.is_empty() {
1190 if let Some(old_index) = old_index {
1191 old_index.copy_index_file(POSTINGS_FILENAME, store).await?;
1193 } else {
1194 let mut writer = store
1196 .new_index_file(POSTINGS_FILENAME, POSTINGS_SCHEMA.clone())
1197 .await?;
1198 writer.finish().await?;
1199 }
1200 return Ok(());
1201 }
1202
1203 let mut index_to_copy = self.merge_spills(spill_files).await?;
1204
1205 if let Some(old_index) = old_index {
1206 index_to_copy = self.merge_old_index(index_to_copy, old_index).await?;
1207 }
1208
1209 let reader = self
1210 .spill_store
1211 .open_index_file(&Self::spill_filename(index_to_copy))
1212 .await?;
1213
1214 let num_rows = reader.num_rows();
1215 let mut offset = 0;
1216
1217 while offset < num_rows {
1218 let batch_size = std::cmp::min(num_rows - offset, 64);
1219 let batch = reader.read_range(offset..offset + batch_size, None).await?;
1220 writer.write_record_batch(batch).await?;
1221 offset += batch_size;
1222 }
1223
1224 writer.finish().await
1225 }
1226}
1227
1228#[derive(Debug, Default)]
1229pub struct NGramIndexPlugin;
1230
1231impl NGramIndexPlugin {
1232 pub async fn train_ngram_index(
1233 batches_source: SendableRecordBatchStream,
1234 index_store: &dyn IndexStore,
1235 ) -> Result<()> {
1236 let mut builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default())?;
1237
1238 let spill_files = builder.train(batches_source).await?;
1239
1240 builder.write_index(index_store, spill_files, None).await
1241 }
1242}
1243
1244#[async_trait]
1245impl ScalarIndexPlugin for NGramIndexPlugin {
1246 fn name(&self) -> &str {
1247 "NGram"
1248 }
1249
1250 fn new_training_request(
1251 &self,
1252 _params: &str,
1253 field: &Field,
1254 ) -> Result<Box<dyn TrainingRequest>> {
1255 if !matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
1256 return Err(Error::invalid_input_source(format!(
1257 "A ngram index can only be created on a Utf8 or LargeUtf8 field. Column has type {:?}",
1258 field.data_type()
1259 )
1260 .into()));
1261 }
1262 Ok(Box::new(DefaultTrainingRequest::new(
1263 TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
1264 )))
1265 }
1266
1267 fn provides_exact_answer(&self) -> bool {
1268 false
1269 }
1270
1271 fn version(&self) -> u32 {
1272 NGRAM_INDEX_VERSION
1273 }
1274
1275 fn new_query_parser(
1276 &self,
1277 index_name: String,
1278 _index_details: &prost_types::Any,
1279 ) -> Option<Box<dyn ScalarQueryParser>> {
1280 Some(Box::new(TextQueryParser::new(
1281 index_name,
1282 self.name().to_string(),
1283 true,
1284 )))
1285 }
1286
1287 async fn train_index(
1288 &self,
1289 data: SendableRecordBatchStream,
1290 index_store: &dyn IndexStore,
1291 _request: Box<dyn TrainingRequest>,
1292 fragment_ids: Option<Vec<u32>>,
1293 _progress: Arc<dyn crate::progress::IndexBuildProgress>,
1294 ) -> Result<CreatedIndex> {
1295 if fragment_ids.is_some() {
1296 return Err(Error::invalid_input_source(
1297 "NGram index does not support fragment training".into(),
1298 ));
1299 }
1300
1301 Self::train_ngram_index(data, index_store).await?;
1302 Ok(CreatedIndex {
1303 index_details: prost_types::Any::from_msg(&pbold::NGramIndexDetails::default())
1304 .unwrap(),
1305 index_version: NGRAM_INDEX_VERSION,
1306 files: Some(index_store.list_files_with_sizes().await?),
1307 })
1308 }
1309
1310 async fn load_index(
1311 &self,
1312 index_store: Arc<dyn IndexStore>,
1313 _index_details: &prost_types::Any,
1314 frag_reuse_index: Option<Arc<FragReuseIndex>>,
1315 cache: &LanceCache,
1316 ) -> Result<Arc<dyn ScalarIndex>> {
1317 Ok(NGramIndex::load(index_store, frag_reuse_index, cache).await? as Arc<dyn ScalarIndex>)
1318 }
1319}
1320
1321#[cfg(test)]
1322mod tests {
1323 use std::{
1324 collections::{HashMap, HashSet},
1325 sync::Arc,
1326 };
1327
1328 use arrow::datatypes::UInt64Type;
1329 use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
1330 use arrow_schema::{DataType, Field, Schema};
1331 use datafusion::{
1332 execution::SendableRecordBatchStream, physical_plan::stream::RecordBatchStreamAdapter,
1333 };
1334 use datafusion_common::DataFusionError;
1335 use futures::{TryStreamExt, stream};
1336 use itertools::Itertools;
1337 use lance_core::{
1338 ROW_ID,
1339 cache::LanceCache,
1340 utils::{mask::RowAddrTreeMap, tempfile::TempDir},
1341 };
1342 use lance_datagen::{BatchCount, ByteCount, RowCount};
1343 use lance_io::object_store::ObjectStore;
1344 use lance_tokenizer::TextAnalyzer;
1345
1346 use crate::scalar::{
1347 ScalarIndex, SearchResult, TextQuery,
1348 lance_format::LanceIndexStore,
1349 ngram::{NGramIndex, NGramIndexBuilder, NGramIndexBuilderOptions},
1350 };
1351 use crate::{metrics::NoOpMetricsCollector, scalar::registry::VALUE_COLUMN_NAME};
1352
1353 use super::{NGRAM_TOKENIZER, ngram_to_token, tokenize_visitor};
1354
1355 fn collect_tokens(analyzer: &TextAnalyzer, text: &str) -> Vec<String> {
1356 let mut tokens = Vec::with_capacity(text.len() * 3);
1357 tokenize_visitor(analyzer, text, |token| tokens.push(token.to_owned()));
1358 tokens
1359 }
1360
/// Checks the trigrams produced by the shared NGram tokenizer for a set of inputs.
#[test]
fn test_tokenizer() {
    let tokenizer = NGRAM_TOKENIZER.clone();

    // (input, expected tokens in emission order)
    let cases: Vec<(&str, Vec<&str>)> = vec![
        // Accented characters come out as plain ASCII.
        ("café", vec!["caf", "afe"]),
        // Digits are kept alongside letters.
        ("a1b2", vec!["a1b", "1b2"]),
        // Non-alphanumeric characters split the text; pieces shorter than three
        // characters produce no token.
        ("abc👍b!c24", vec!["abc", "c24"]),
        // 'ß' comes out as "ss".
        ("anstoß", vec!["ans", "nst", "sto", "tos", "oss"]),
        // Tokens are lower-cased.
        ("ABC", vec!["abc"]),
        // Repeated trigrams are reported each time they occur.
        ("ababab", vec!["aba", "bab", "aba", "bab"]),
    ];

    for (input, expected) in cases {
        let tokens = collect_tokens(&tokenizer, input);
        assert_eq!(tokens, expected);
    }
}
1398
1399 async fn do_train(
1400 mut builder: NGramIndexBuilder,
1401 data: SendableRecordBatchStream,
1402 ) -> (NGramIndex, Arc<TempDir>) {
1403 let spill_files = builder.train(data).await.unwrap();
1404
1405 let tmpdir = Arc::new(TempDir::default());
1406 let test_store = LanceIndexStore::new(
1407 Arc::new(ObjectStore::local()),
1408 tmpdir.obj_path(),
1409 Arc::new(LanceCache::no_cache()),
1410 );
1411
1412 builder
1413 .write_index(&test_store, spill_files, None)
1414 .await
1415 .unwrap();
1416
1417 (
1418 NGramIndex::from_store(Arc::new(test_store), None, &LanceCache::no_cache())
1419 .await
1420 .unwrap(),
1421 tmpdir,
1422 )
1423 }
1424
1425 async fn get_posting_list_for_trigram(index: &NGramIndex, trigram: &str) -> Vec<u64> {
1426 let token = ngram_to_token(trigram, 3);
1427 let row_offset = index.tokens[&token];
1428 let list = index
1429 .list_reader
1430 .ngram_list(row_offset, &NoOpMetricsCollector)
1431 .await
1432 .unwrap();
1433 list.bitmap.iter().sorted().collect()
1434 }
1435
1436 async fn get_null_posting_list(index: &NGramIndex) -> Vec<u64> {
1437 let row_offset = index.tokens[&0];
1438 let list = index
1439 .list_reader
1440 .ngram_list(row_offset, &NoOpMetricsCollector)
1441 .await
1442 .unwrap();
1443 list.bitmap.iter().sorted().collect()
1444 }
1445
1446 #[test_log::test(tokio::test)]
1447 async fn test_basic_ngram_index() {
1448 let data = StringArray::from_iter_values([
1449 "cat",
1450 "dog",
1451 "cat dog",
1452 "dog cat",
1453 "elephant",
1454 "mouse",
1455 "rhino",
1456 "giraffe",
1457 "rhinos nose",
1458 ]);
1459 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1460 let schema = Arc::new(Schema::new(vec![
1461 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, false),
1462 Field::new(ROW_ID, DataType::UInt64, false),
1463 ]));
1464 let data =
1465 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1466 let data = Box::pin(RecordBatchStreamAdapter::new(
1467 schema,
1468 stream::once(std::future::ready(Ok(data))),
1469 ));
1470
1471 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1472
1473 let (index, _tmpdir) = do_train(builder, data).await;
1474 assert_eq!(index.tokens.len(), 21);
1475
1476 let res = index
1478 .search(
1479 &TextQuery::StringContains("cat".to_string()),
1480 &NoOpMetricsCollector,
1481 )
1482 .await
1483 .unwrap();
1484
1485 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 2, 3]));
1486
1487 assert_eq!(expected, res);
1488
1489 let res = index
1491 .search(
1492 &TextQuery::StringContains("nos nos".to_string()),
1493 &NoOpMetricsCollector,
1494 )
1495 .await
1496 .unwrap();
1497 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1498 assert_eq!(expected, res);
1499
1500 let res = index
1502 .search(
1503 &TextQuery::StringContains("tdo".to_string()),
1504 &NoOpMetricsCollector,
1505 )
1506 .await
1507 .unwrap();
1508 let expected = SearchResult::exact(RowAddrTreeMap::new());
1509 assert_eq!(expected, res);
1510
1511 let res = index
1513 .search(
1514 &TextQuery::StringContains("inose".to_string()),
1515 &NoOpMetricsCollector,
1516 )
1517 .await
1518 .unwrap();
1519 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1520 assert_eq!(expected, res);
1521
1522 let res = index
1524 .search(
1525 &TextQuery::StringContains("ab".to_string()),
1526 &NoOpMetricsCollector,
1527 )
1528 .await
1529 .unwrap();
1530 let expected = SearchResult::at_least(RowAddrTreeMap::new());
1531 assert_eq!(expected, res);
1532
1533 let res = index
1535 .search(
1536 &TextQuery::StringContains("no nos".to_string()),
1537 &NoOpMetricsCollector,
1538 )
1539 .await
1540 .unwrap();
1541 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([8]));
1542 assert_eq!(expected, res);
1543 }
1544
1545 fn test_data_schema() -> Arc<Schema> {
1546 Arc::new(Schema::new(vec![
1547 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1548 Field::new(ROW_ID, DataType::UInt64, false),
1549 ]))
1550 }
1551
1552 fn simple_data_with_nulls() -> SendableRecordBatchStream {
1553 let data = StringArray::from_iter(&[Some("cat"), Some("dog"), None, None, Some("cat dog")]);
1554 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64));
1555 let schema = test_data_schema();
1556 let data =
1557 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1558 Box::pin(RecordBatchStreamAdapter::new(
1559 schema,
1560 stream::once(std::future::ready(Ok(data))),
1561 ))
1562 }
1563
1564 #[test_log::test(tokio::test)]
1565 async fn test_ngram_nulls() {
1566 let data = simple_data_with_nulls();
1567
1568 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1569
1570 let (index, _tmpdir) = do_train(builder, data).await;
1571 assert_eq!(index.tokens.len(), 3);
1572
1573 let res = index
1574 .search(
1575 &TextQuery::StringContains("cat".to_string()),
1576 &NoOpMetricsCollector,
1577 )
1578 .await
1579 .unwrap();
1580 let expected = SearchResult::at_most(RowAddrTreeMap::from_iter([0, 4]));
1581 assert_eq!(expected, res);
1582
1583 let null_posting_list = get_null_posting_list(&index).await;
1584 assert_eq!(null_posting_list, vec![2, 3]);
1585
1586 }
1588
1589 fn empty_data() -> SendableRecordBatchStream {
1590 Box::pin(RecordBatchStreamAdapter::new(
1591 test_data_schema(),
1592 stream::empty::<lance_core::error::DataFusionResult<RecordBatch>>(),
1593 ))
1594 }
1595
1596 #[test_log::test(tokio::test)]
1597 async fn test_train_empty() {
1598 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1599
1600 let (index, _tmpdir) = do_train(builder, empty_data()).await;
1601 assert_eq!(index.tokens.len(), 0);
1602 }
1603
1604 #[test_log::test(tokio::test)]
1605 async fn test_update_empty() {
1606 let data = simple_data_with_nulls();
1607
1608 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1609 let (index, _tmpdir) = do_train(builder, empty_data()).await;
1610
1611 let new_tmpdir = Arc::new(TempDir::default());
1612 let test_store = Arc::new(LanceIndexStore::new(
1613 Arc::new(ObjectStore::local()),
1614 new_tmpdir.obj_path(),
1615 Arc::new(LanceCache::no_cache()),
1616 ));
1617
1618 index.update(data, test_store.as_ref(), None).await.unwrap();
1619
1620 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1621 .await
1622 .unwrap();
1623 assert_eq!(index.tokens.len(), 3);
1624 }
1625
1626 async fn row_ids_in_index(index: &NGramIndex) -> Vec<u64> {
1627 let mut row_ids = HashSet::new();
1628 for row_offset in index.tokens.values() {
1629 let list = index
1630 .list_reader
1631 .ngram_list(*row_offset, &NoOpMetricsCollector)
1632 .await
1633 .unwrap();
1634 row_ids.extend(list.bitmap.iter());
1635 }
1636 row_ids.into_iter().sorted().collect()
1637 }
1638
1639 #[test_log::test(tokio::test)]
1640 async fn test_ngram_index_remap() {
1641 let data = simple_data_with_nulls();
1642 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1643 let (index, _tmpdir) = do_train(builder, data).await;
1644
1645 let row_ids = row_ids_in_index(&index).await;
1646 assert_eq!(row_ids, vec![0, 1, 2, 3, 4]);
1647
1648 let new_tmpdir = Arc::new(TempDir::default());
1649 let test_store = Arc::new(LanceIndexStore::new(
1650 Arc::new(ObjectStore::local()),
1651 new_tmpdir.obj_path(),
1652 Arc::new(LanceCache::no_cache()),
1653 ));
1654
1655 let remapping = HashMap::from([(2, Some(100)), (3, None), (4, Some(101))]);
1656 index.remap(&remapping, test_store.as_ref()).await.unwrap();
1657
1658 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1659 .await
1660 .unwrap();
1661 let row_ids = row_ids_in_index(&index).await;
1662 assert_eq!(row_ids, vec![0, 1, 100, 101]);
1663
1664 let null_posting_list = get_null_posting_list(&index).await;
1665 assert_eq!(null_posting_list, vec![100]);
1666 }
1667
1668 #[test_log::test(tokio::test)]
1669 async fn test_ngram_index_merge() {
1670 let data = simple_data_with_nulls();
1671 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions::default()).unwrap();
1672 let (index, _tmpdir) = do_train(builder, data).await;
1673
1674 let data = StringArray::from_iter(&[Some("giraffe"), Some("cat"), None]);
1675 let row_ids = UInt64Array::from_iter_values((0..data.len()).map(|i| i as u64 + 100));
1676 let schema = Arc::new(Schema::new(vec![
1677 Field::new(VALUE_COLUMN_NAME, DataType::Utf8, true),
1678 Field::new(ROW_ID, DataType::UInt64, false),
1679 ]));
1680 let data =
1681 RecordBatch::try_new(schema.clone(), vec![Arc::new(data), Arc::new(row_ids)]).unwrap();
1682 let data = Box::pin(RecordBatchStreamAdapter::new(
1683 schema,
1684 stream::once(std::future::ready(Ok(data))),
1685 ));
1686
1687 let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1688 assert_eq!(posting_list, vec![0, 4]);
1689
1690 let new_tmpdir = Arc::new(TempDir::default());
1691 let test_store = Arc::new(LanceIndexStore::new(
1692 Arc::new(ObjectStore::local()),
1693 new_tmpdir.obj_path(),
1694 Arc::new(LanceCache::no_cache()),
1695 ));
1696
1697 index.update(data, test_store.as_ref(), None).await.unwrap();
1698
1699 let index = NGramIndex::from_store(test_store, None, &LanceCache::no_cache())
1700 .await
1701 .unwrap();
1702 let row_ids = row_ids_in_index(&index).await;
1703 assert_eq!(row_ids, vec![0, 1, 2, 3, 4, 100, 101, 102]);
1704
1705 let posting_list = get_posting_list_for_trigram(&index, "cat").await;
1706 assert_eq!(posting_list, vec![0, 4, 101]);
1707
1708 let posting_list = get_posting_list_for_trigram(&index, "ffe").await;
1709 assert_eq!(posting_list, vec![100]);
1710
1711 let posting_list = get_null_posting_list(&index).await;
1712 assert_eq!(posting_list, vec![2, 3, 102]);
1713 }
1714
1715 #[test_log::test(tokio::test)]
1716 async fn test_ngram_index_with_spill() {
1717 let (data, schema) = lance_datagen::gen_batch()
1718 .col(
1719 VALUE_COLUMN_NAME,
1720 lance_datagen::array::rand_utf8(ByteCount::from(50), false),
1721 )
1722 .col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
1723 .into_reader_stream(RowCount::from(128), BatchCount::from(32));
1724
1725 let data = Box::pin(RecordBatchStreamAdapter::new(
1726 schema,
1727 data.map_err(|arrow_err| DataFusionError::ArrowError(Box::new(arrow_err), None)),
1728 ));
1729
1730 let builder = NGramIndexBuilder::try_new(NGramIndexBuilderOptions {
1731 tokens_per_spill: 100,
1732 })
1733 .unwrap();
1734
1735 let (index, _tmpdir) = do_train(builder, data).await;
1736
1737 assert_eq!(index.tokens.len(), 29012);
1738 }
1739}