1use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::{
7 cmp::{min, Reverse},
8 collections::BinaryHeap,
9};
10use std::{collections::HashMap, ops::Range};
11
12use crate::metrics::NoOpMetricsCollector;
13use crate::prefilter::NoFilter;
14use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
15use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
16use arrow::{
17 array::{
18 AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
19 },
20 buffer::OffsetBuffer,
21};
22use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
23use arrow_array::{
24 Array, ArrayRef, BooleanArray, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait,
25 RecordBatch, UInt32Array, UInt64Array,
26};
27use arrow_schema::{DataType, Field, Schema, SchemaRef};
28use async_trait::async_trait;
29use datafusion::execution::SendableRecordBatchStream;
30use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31use datafusion_common::DataFusionError;
32use deepsize::DeepSizeOf;
33use fst::{Automaton, IntoStreamer, Streamer};
34use futures::{stream, FutureExt, StreamExt, TryStreamExt};
35use itertools::Itertools;
36use lance_arrow::{iter_str_array, RecordBatchExt};
37use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
38use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
39use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
40use lance_core::{
41 container::list::ExpLinkedList,
42 utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
43};
44use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
45use roaring::RoaringBitmap;
46use snafu::location;
47use std::sync::LazyLock;
48use tokio::task::spawn_blocking;
49use tracing::{info, instrument};
50
51use super::{
52 builder::{
53 doc_file_path, inverted_list_schema, posting_file_path, token_file_path, ScoredDoc,
54 BLOCK_SIZE,
55 },
56 iter::PlainPostingListIterator,
57 query::*,
58 scorer::{idf, IndexBM25Scorer, Scorer, B, K1},
59};
60use super::{
61 builder::{InnerBuilder, PositionRecorder},
62 encoding::compress_posting_list,
63 iter::CompressedPostingListIterator,
64};
65use super::{encoding::compress_positions, iter::PostingListIterator};
66use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
67use crate::frag_reuse::FragReuseIndex;
68use crate::pbold;
69use crate::scalar::inverted::lance_tokenizer::TextTokenizer;
70use crate::scalar::inverted::scorer::MemBM25Scorer;
71use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
72use crate::scalar::{
73 AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
74 ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
75};
76use crate::Index;
77use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
78use std::str::FromStr;
79
/// Version written for newly created inverted indices (FST token-set format);
/// indices with the legacy Arrow token set report version 0 (see `remap`/`update`).
pub const INVERTED_INDEX_VERSION: u32 = 1;

// File names of the legacy single-partition layout (no metadata file).
pub const TOKENS_FILE: &str = "tokens.lance";
pub const INVERT_LIST_FILE: &str = "invert.lance";
pub const DOCS_FILE: &str = "docs.lance";
// Metadata file of the partitioned layout; failure to open it is what routes
// `InvertedIndex::load` to the legacy loader.
pub const METADATA_FILE: &str = "metadata.lance";

// Column names used across the index files.
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
// Columns of the FST-format token-set batch (single row).
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
pub const SCORE_COL: &str = "_score";
// Schema-metadata key recording which `TokenSetFormat` an index was written with.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";

/// Nullable Float32 field holding the BM25 score in FTS results.
pub static SCORE_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
/// Output schema of FTS searches: (row id, score).
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
/// Schema with only the row-id column, used by `do_search`.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
110
/// Per-partition result of the parallel BM25 fan-out: the token text for each
/// term position plus the candidate documents found in that partition.
#[derive(Debug)]
struct PartitionCandidates {
    // Token text indexed by the posting iterator's `term_index()`; used by the
    // caller to look up per-token IDF weights.
    tokens_by_position: Vec<String>,
    // Candidates with per-term frequencies; final scoring happens in the caller.
    candidates: Vec<DocCandidate>,
}
116
117impl PartitionCandidates {
118 fn empty() -> Self {
119 Self {
120 tokens_by_position: Vec::new(),
121 candidates: Vec::new(),
122 }
123 }
124}
125
/// On-disk representation of a partition's token set.
///
/// `Arrow` stores one `(token, token_id)` row per token; `Fst` stores a single
/// serialized `fst::Map` blob plus next-id / total-length columns. `Fst` is the
/// default for newly written indices.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
    Arrow,
    #[default]
    Fst,
}
132
133impl Display for TokenSetFormat {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 match self {
136 Self::Arrow => f.write_str("arrow"),
137 Self::Fst => f.write_str("fst"),
138 }
139 }
140}
141
142impl FromStr for TokenSetFormat {
143 type Err = Error;
144
145 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
146 match s.trim() {
147 "" => Ok(Self::Arrow),
148 "arrow" => Ok(Self::Arrow),
149 "fst" => Ok(Self::Fst),
150 other => Err(Error::Index {
151 message: format!("unsupported token set format {}", other),
152 location: location!(),
153 }),
154 }
155 }
156}
157
impl DeepSizeOf for TokenSetFormat {
    // A plain `Copy` enum: no heap-allocated children to measure.
    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
        0
    }
}
163
/// A full-text (BM25) inverted index, split into one or more partitions.
#[derive(Clone)]
pub struct InvertedIndex {
    // Tokenizer/indexing parameters the index was built with.
    params: InvertedIndexParams,
    store: Arc<dyn IndexStore>,
    // Tokenizer built from `params`; cloned per query (see `do_search`).
    tokenizer: Box<dyn LanceTokenizer>,
    // Serialization format of the partitions' token sets.
    token_set_format: TokenSetFormat,
    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
}
172
impl Debug for InvertedIndex {
    // Manual impl: skips `store` and `tokenizer`, which are trait objects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InvertedIndex")
            .field("params", &self.params)
            .field("token_set_format", &self.token_set_format)
            .field("partitions", &self.partitions)
            .finish()
    }
}
182
impl DeepSizeOf for InvertedIndex {
    // Only the partitions are measured; the store and tokenizer trait objects
    // are not counted.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.partitions.deep_size_of_children(context)
    }
}
188
impl InvertedIndex {
    /// Builder seeded from this index, keeping all partitions.
    fn to_builder(&self) -> InvertedIndexBuilder {
        self.to_builder_with_offset(None)
    }

    /// Builder seeded from this index.
    ///
    /// With a `fragment_mask`, only partitions whose ids match the mask
    /// (see `InvertedPartition::belongs_to_fragment`) are carried over.
    /// Legacy indices carry no partitions and no source store, forcing a
    /// rebuild from the original data.
    fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
        if self.is_legacy() {
            InvertedIndexBuilder::from_existing_index(
                self.params.clone(),
                None,
                Vec::new(),
                self.token_set_format,
                fragment_mask,
            )
        } else {
            let partitions = match fragment_mask {
                Some(fragment_mask) => self
                    .partitions
                    .iter()
                    .filter(|part| part.belongs_to_fragment(fragment_mask))
                    .map(|part| part.id())
                    .collect(),
                None => self.partitions.iter().map(|part| part.id()).collect(),
            };

            InvertedIndexBuilder::from_existing_index(
                self.params.clone(),
                Some(self.store.clone()),
                partitions,
                self.token_set_format,
                fragment_mask,
            )
        }
    }

    /// A fresh clone of the tokenizer used to build this index.
    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
        self.tokenizer.clone()
    }

    /// Parameters the index was built with.
    pub fn params(&self) -> &InvertedIndexParams {
        &self.params
    }

    /// Number of partitions in this index.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }

    /// Runs a BM25 top-k search across all partitions.
    ///
    /// Each partition loads its posting lists on the I/O path and then runs
    /// its WAND search on a CPU pool task; results are merged here into a
    /// single top-`limit` heap, with final scores computed from global IDF
    /// weights (so per-partition document frequencies don't skew ranking).
    ///
    /// Returns parallel vectors of row ids and scores, best score first.
    #[instrument(level = "debug", skip_all)]
    pub async fn bm25_search(
        &self,
        tokens: Arc<Tokens>,
        params: Arc<FtsSearchParams>,
        operator: Operator,
        prefilter: Arc<dyn PreFilter>,
        metrics: Arc<dyn MetricsCollector>,
    ) -> Result<(Vec<u64>, Vec<f32>)> {
        let limit = params.limit.unwrap_or(usize::MAX);
        if limit == 0 {
            return Ok((Vec::new(), Vec::new()));
        }
        let mask = prefilter.mask();

        // Min-heap (via `Reverse`) holding the best `limit` docs seen so far.
        let mut candidates = BinaryHeap::new();
        let parts = self
            .partitions
            .iter()
            .map(|part| {
                let part = part.clone();
                let tokens = tokens.clone();
                let params = params.clone();
                let mask = mask.clone();
                let metrics = metrics.clone();
                async move {
                    let postings = part
                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
                        .await?;
                    if postings.is_empty() {
                        return Ok(PartitionCandidates::empty());
                    }
                    // Record which token sits at each term position so the
                    // merger can map frequencies back to IDF weights.
                    let mut tokens_by_position = vec![String::new(); postings.len()];
                    for posting in &postings {
                        let idx = posting.term_index() as usize;
                        tokens_by_position[idx] = posting.token().to_owned();
                    }
                    let params = params.clone();
                    let mask = mask.clone();
                    let metrics = metrics.clone();
                    // WAND evaluation is CPU-bound; keep it off the I/O runtime.
                    spawn_cpu(move || {
                        let candidates = part.bm25_search(
                            params.as_ref(),
                            operator,
                            mask,
                            postings,
                            metrics.as_ref(),
                        )?;
                        Ok(PartitionCandidates {
                            tokens_by_position,
                            candidates,
                        })
                    })
                    .await
                }
            })
            .collect::<Vec<_>>();
        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
        // Scorer over *all* partitions: IDF must be global, not per-partition.
        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
        let mut idf_cache: HashMap<String, f32> = HashMap::new();
        while let Some(res) = parts.try_next().await? {
            if res.candidates.is_empty() {
                continue;
            }
            // Resolve IDF weights once per token per partition result.
            let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
            for token in &res.tokens_by_position {
                let idf_weight = match idf_cache.get(token) {
                    Some(weight) => *weight,
                    None => {
                        let weight = scorer.query_weight(token);
                        idf_cache.insert(token.clone(), weight);
                        weight
                    }
                };
                idf_by_position.push(idf_weight);
            }
            for DocCandidate {
                row_id,
                freqs,
                doc_length,
            } in res.candidates
            {
                let mut score = 0.0;
                for (term_index, freq) in freqs.into_iter() {
                    debug_assert!((term_index as usize) < idf_by_position.len());
                    score +=
                        idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length);
                }
                // Standard top-k: fill to `limit`, then replace the current
                // minimum only when strictly beaten.
                if candidates.len() < limit {
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                } else if candidates.peek().unwrap().0.score.0 < score {
                    candidates.pop();
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                }
            }
        }

        Ok(candidates
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
            .unzip())
    }

    /// Loads the old single-partition layout (tokens/invert/docs files, no
    /// metadata file) as an index with one legacy partition (id 0).
    async fn load_legacy_index(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>> {
        log::warn!("loading legacy FTS index");
        // The three files are independent; load them concurrently.
        let tokens_fut = tokio::spawn({
            let store = store.clone();
            async move {
                let token_reader = store.open_index_file(TOKENS_FILE).await?;
                // Legacy layout stored the params under the "tokenizer" key.
                let tokenizer = token_reader
                    .schema()
                    .metadata
                    .get("tokenizer")
                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
                    .transpose()?
                    .unwrap_or_default();
                let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
                Result::Ok((tokenizer, tokens))
            }
        });
        let invert_list_fut = tokio::spawn({
            let store = store.clone();
            let index_cache_clone = index_cache.clone();
            async move {
                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
                let invert_list =
                    PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
                Result::Ok(Arc::new(invert_list))
            }
        });
        let docs_fut = tokio::spawn({
            let store = store.clone();
            async move {
                let docs_reader = store.open_index_file(DOCS_FILE).await?;
                let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
                Result::Ok(docs)
            }
        });

        let (tokenizer_config, tokens) = tokens_fut.await??;
        let inverted_list = invert_list_fut.await??;
        let docs = docs_fut.await??;

        let tokenizer = tokenizer_config.build()?;

        Ok(Arc::new(Self {
            params: tokenizer_config,
            store: store.clone(),
            tokenizer,
            token_set_format: TokenSetFormat::Arrow,
            partitions: vec![Arc::new(InvertedPartition {
                id: 0,
                store,
                tokens,
                inverted_list,
                docs,
                token_set_format: TokenSetFormat::Arrow,
            })],
        }))
    }

    /// True when this index uses the old single-partition layout.
    pub fn is_legacy(&self) -> bool {
        self.partitions.len() == 1 && self.partitions[0].is_legacy()
    }

    /// Loads an inverted index from `store`.
    ///
    /// Reads params, partition ids, and token-set format from the metadata
    /// file's schema metadata, then loads all partitions concurrently. Any
    /// failure to open the metadata file is treated as the legacy layout.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>>
    where
        Self: Sized,
    {
        match store.open_index_file(METADATA_FILE).await {
            Ok(reader) => {
                let params = reader.schema().metadata.get("params").ok_or(Error::Index {
                    message: "params not found in metadata".to_owned(),
                    location: location!(),
                })?;
                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
                let partitions =
                    reader
                        .schema()
                        .metadata
                        .get("partitions")
                        .ok_or(Error::Index {
                            message: "partitions not found in metadata".to_owned(),
                            location: location!(),
                        })?;
                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
                // Older metadata may lack the format key; default to Arrow.
                let token_set_format = reader
                    .schema()
                    .metadata
                    .get(TOKEN_SET_FORMAT_KEY)
                    .map(|name| TokenSetFormat::from_str(name))
                    .transpose()?
                    .unwrap_or(TokenSetFormat::Arrow);

                let format = token_set_format;
                let partitions = partitions.into_iter().map(|id| {
                    let store = store.clone();
                    let frag_reuse_index_clone = frag_reuse_index.clone();
                    // Namespace cache entries per partition.
                    let index_cache_for_part =
                        index_cache.with_key_prefix(format!("part-{}", id).as_str());
                    let token_set_format = format;
                    async move {
                        Result::Ok(Arc::new(
                            InvertedPartition::load(
                                store,
                                id,
                                frag_reuse_index_clone,
                                &index_cache_for_part,
                                token_set_format,
                            )
                            .await?,
                        ))
                    }
                });
                let partitions = stream::iter(partitions)
                    .buffer_unordered(store.io_parallelism())
                    .try_collect::<Vec<_>>()
                    .await?;

                let tokenizer = params.build()?;
                Ok(Arc::new(Self {
                    params,
                    store,
                    tokenizer,
                    token_set_format,
                    partitions,
                }))
            }
            Err(_) => {
                // No metadata file: assume the legacy layout.
                Self::load_legacy_index(store, frag_reuse_index, index_cache).await
            }
        }
    }
}
494
#[async_trait]
impl Index for InvertedIndex {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    // An inverted index can never serve as a vector index.
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
        Err(Error::invalid_input(
            "inverted index cannot be cast to vector index",
            location!(),
        ))
    }

    /// Build params plus token/doc counts summed over all partitions.
    fn statistics(&self) -> Result<serde_json::Value> {
        let num_tokens = self
            .partitions
            .iter()
            .map(|part| part.tokens.len())
            .sum::<usize>();
        let num_docs = self
            .partitions
            .iter()
            .map(|part| part.docs.len())
            .sum::<usize>();
        Ok(serde_json::json!({
            "params": self.params,
            "num_tokens": num_tokens,
            "num_docs": num_docs,
        }))
    }

    /// Warms each partition's posting-list reader ahead of queries.
    async fn prewarm(&self) -> Result<()> {
        for part in &self.partitions {
            part.inverted_list.prewarm().await?;
        }
        Ok(())
    }

    fn index_type(&self) -> crate::IndexType {
        crate::IndexType::Inverted
    }

    // Not supported for inverted indices.
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        unimplemented!()
    }
}
545
546impl InvertedIndex {
547 async fn do_search(&self, text: &str) -> Result<RecordBatch> {
549 let params = FtsSearchParams::new();
550 let mut tokenizer = self.tokenizer.clone();
551 let tokens = collect_query_tokens(text, &mut tokenizer, None);
552
553 let (doc_ids, _) = self
554 .bm25_search(
555 Arc::new(tokens),
556 params.into(),
557 Operator::And,
558 Arc::new(NoFilter),
559 Arc::new(NoOpMetricsCollector),
560 )
561 .boxed()
562 .await?;
563
564 Ok(RecordBatch::try_new(
565 ROW_ID_SCHEMA.clone(),
566 vec![Arc::new(UInt64Array::from(doc_ids))],
567 )?)
568 }
569}
570
#[async_trait]
impl ScalarIndex for InvertedIndex {
    /// Token-containment search: returns the set of rows whose indexed text
    /// contains all query tokens (AND semantics via `do_search`).
    #[instrument(level = "debug", skip_all)]
    async fn search(
        &self,
        query: &dyn AnyQuery,
        _metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        // Callers of this index only dispatch TokenQuery here; a mismatch is a bug.
        let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();

        match query {
            TokenQuery::TokensContains(text) => {
                let records = self.do_search(text).await?;
                let row_ids = records
                    .column(0)
                    .as_any()
                    .downcast_ref::<UInt64Array>()
                    .unwrap();
                let row_ids = row_ids.iter().flatten().collect_vec();
                // `at_most`: BM25 candidates are a superset guarantee, not exact.
                Ok(SearchResult::at_most(RowAddrTreeMap::from_iter(row_ids)))
            }
        }
    }

    fn can_remap(&self) -> bool {
        true
    }

    /// Rewrites the index into `dest_store` with row ids remapped
    /// (`None` in the mapping removes the row).
    async fn remap(
        &self,
        mapping: &HashMap<u64, Option<u64>>,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        self.to_builder()
            .remap(mapping, self.store.clone(), dest_store)
            .await?;

        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;

        // Arrow-format token sets predate versioning and report 0.
        let index_version = match self.token_set_format {
            TokenSetFormat::Arrow => 0,
            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
        };

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&details).unwrap(),
            index_version,
        })
    }

    /// Incorporates `new_data` into the index, writing to `dest_store`.
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        self.to_builder().update(new_data, dest_store).await?;

        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;

        // Same version rule as `remap`.
        let index_version = match self.token_set_format {
            TokenSetFormat::Arrow => 0,
            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
        };

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&details).unwrap(),
            index_version,
        })
    }

    /// Legacy indices must be fully rebuilt (old data required); partitioned
    /// indices can be updated from new data alone.
    fn update_criteria(&self) -> UpdateCriteria {
        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
        if self.is_legacy() {
            UpdateCriteria::requires_old_data(criteria)
        } else {
            UpdateCriteria::only_new_data(criteria)
        }
    }

    /// Serializes the current params back into creation parameters, filling
    /// in the "simple" base tokenizer when none was recorded.
    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
        let mut params = self.params.clone();
        if params.base_tokenizer.is_empty() {
            params.base_tokenizer = "simple".to_string();
        }

        let params_json = serde_json::to_string(&params)?;

        Ok(ScalarIndexParams {
            index_type: BuiltinIndexType::Inverted.as_str().to_string(),
            params: Some(params_json),
        })
    }
}
667
/// One partition of an inverted index: its token dictionary, posting lists,
/// and document set.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct InvertedPartition {
    // Partition id; also used to derive the partition's file names and,
    // via `belongs_to_fragment`, matched against fragment masks.
    id: u64,
    store: Arc<dyn IndexStore>,
    pub(crate) tokens: TokenSet,
    pub(crate) inverted_list: Arc<PostingListReader>,
    pub(crate) docs: DocSet,
    token_set_format: TokenSetFormat,
}
678
impl InvertedPartition {
    /// True when every bit set in `fragment_mask` is also set in this
    /// partition's id.
    // NOTE(review): this assumes partition ids encode fragment bits — the
    // encoding is established outside this file; confirm against the builder.
    pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
        (self.id() & fragment_mask) == fragment_mask
    }

    pub fn id(&self) -> u64 {
        self.id
    }

    pub fn store(&self) -> &dyn IndexStore {
        self.store.as_ref()
    }

    /// Legacy posting files carry no per-token lengths.
    pub fn is_legacy(&self) -> bool {
        self.inverted_list.lengths.is_none()
    }

    /// Loads one partition (token set, posting lists, doc set) from `store`,
    /// using the per-partition file names derived from `id`.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        id: u64,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
        token_set_format: TokenSetFormat,
    ) -> Result<Self> {
        let token_file = store.open_index_file(&token_file_path(id)).await?;
        let tokens = TokenSet::load(token_file, token_set_format).await?;
        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
        let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
        let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;

        Ok(Self {
            id,
            store,
            tokens,
            inverted_list: Arc::new(inverted_list),
            docs,
            token_set_format,
        })
    }

    // Token text -> token id lookup within this partition.
    fn map(&self, token: &str) -> Option<u32> {
        self.tokens.get(token)
    }

    /// Expands each query token into up to `max_expansions` indexed tokens
    /// within the configured Levenshtein distance, using the FST token map.
    ///
    /// When a prefix length is in effect, candidates are additionally pinned
    /// to share the token's prefix. Errors if the token set is not FST-backed.
    pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
        for token in tokens {
            let fuzziness = match params.fuzziness {
                Some(fuzziness) => fuzziness,
                // Fuzziness scaled from token length when not given explicitly.
                None => MatchQuery::auto_fuzziness(token),
            };
            let lev =
                fst::automaton::Levenshtein::new(token, fuzziness).map_err(|e| Error::Index {
                    message: format!("failed to construct the fuzzy query: {}", e),
                    location: location!(),
                })?;

            let base_len = tokens.token_type().prefix_len(token) as u32;
            if let TokenMap::Fst(ref map) = self.tokens.tokens {
                match base_len + params.prefix_length {
                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
                    prefix_length => {
                        // Intersect the Levenshtein automaton with a
                        // starts-with automaton over the fixed prefix.
                        let prefix = &token[..min(prefix_length as usize, token.len())];
                        let prefix = fst::automaton::Str::new(prefix).starts_with();
                        take_fst_keys(
                            map.search(lev.intersection(prefix)),
                            &mut new_tokens,
                            params.max_expansions,
                        )
                    }
                }
            } else {
                return Err(Error::Index {
                    message: "tokens is not fst, which is not expected".to_owned(),
                    location: location!(),
                });
            }
        }
        Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
    }

    /// Resolves query tokens to token ids and loads their posting lists.
    ///
    /// Phrase queries bail out with an empty result as soon as any token is
    /// missing (all terms must be present) and keep the original token order;
    /// other queries sort and dedup token ids. The iterator's position
    /// (`term_index`) identifies the term for later scoring.
    #[instrument(level = "debug", skip_all)]
    pub async fn load_posting_lists(
        &self,
        tokens: &Tokens,
        params: &FtsSearchParams,
        metrics: &dyn MetricsCollector,
    ) -> Result<Vec<PostingIterator>> {
        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
        let is_phrase_query = params.phrase_slop.is_some();
        let tokens = match is_fuzzy {
            true => self.expand_fuzzy(tokens, params)?,
            false => tokens.clone(),
        };
        let mut token_ids = Vec::with_capacity(tokens.len());
        for token in tokens {
            let token_id = self.map(&token);
            if let Some(token_id) = token_id {
                token_ids.push((token_id, token));
            } else if is_phrase_query {
                // A phrase cannot match if any of its terms is absent.
                return Ok(Vec::new());
            }
        }
        if token_ids.is_empty() {
            return Ok(Vec::new());
        }
        if !is_phrase_query {
            token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
            token_ids.dedup_by_key(|(token_id, _)| *token_id);
        }

        let num_docs = self.docs.len();
        stream::iter(token_ids)
            .enumerate()
            .map(|(position, (token_id, token))| async move {
                let posting = self
                    .inverted_list
                    .posting_list(token_id, is_phrase_query, metrics)
                    .await?;

                Result::Ok(PostingIterator::new(
                    token,
                    token_id,
                    position as u32,
                    posting,
                    num_docs,
                ))
            })
            .buffered(self.store.io_parallelism())
            .try_collect::<Vec<_>>()
            .await
    }

    /// Runs the WAND algorithm over the given posting iterators and returns
    /// candidate documents (with per-term frequencies) for this partition.
    #[instrument(level = "debug", skip_all)]
    pub fn bm25_search(
        &self,
        params: &FtsSearchParams,
        operator: Operator,
        mask: Arc<RowAddrMask>,
        postings: Vec<PostingIterator>,
        metrics: &dyn MetricsCollector,
    ) -> Result<Vec<DocCandidate>> {
        if postings.is_empty() {
            return Ok(Vec::new());
        }

        // Scorer over this partition only; global IDF is applied by the caller.
        let scorer = IndexBM25Scorer::new(std::iter::once(self));
        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
        let hits = wand.search(params, mask, metrics)?;
        Ok(hits)
    }

    /// Converts this read-only partition back into a mutable builder by
    /// reading all posting lists back into memory.
    pub async fn into_builder(self) -> Result<InnerBuilder> {
        let mut builder = InnerBuilder::new(
            self.id,
            self.inverted_list.has_positions(),
            self.token_set_format,
        );
        builder.tokens = self.tokens;
        builder.docs = self.docs;

        builder
            .posting_lists
            .reserve_exact(self.inverted_list.len());
        for posting_list in self
            .inverted_list
            .read_all(self.inverted_list.has_positions())
            .await?
        {
            let posting_list = posting_list?;
            builder
                .posting_lists
                .push(posting_list.into_builder(&builder.docs));
        }
        Ok(builder)
    }
}
874
/// The token -> token-id mapping.
///
/// `HashMap` is the mutable form required while indexing (see `TokenSet::add`);
/// `Fst` is the compact, immutable form used when reading an index and for
/// fuzzy expansion.
#[derive(Debug, Clone)]
pub enum TokenMap {
    HashMap(HashMap<String, u32>),
    Fst(fst::Map<Vec<u8>>),
}
882
883impl Default for TokenMap {
884 fn default() -> Self {
885 Self::HashMap(HashMap::new())
886 }
887}
888
impl DeepSizeOf for TokenMap {
    fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
        match self {
            Self::HashMap(map) => map.deep_size_of_children(ctx),
            // The fst's serialized byte size stands in for its heap usage.
            Self::Fst(map) => map.as_fst().size(),
        }
    }
}
897
898impl TokenMap {
899 pub fn len(&self) -> usize {
900 match self {
901 Self::HashMap(map) => map.len(),
902 Self::Fst(map) => map.len(),
903 }
904 }
905
906 pub fn is_empty(&self) -> bool {
907 self.len() == 0
908 }
909}
910
/// Token dictionary of one partition, mapping token text to token id.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct TokenSet {
    // Backing map; `HashMap` while indexing, typically `Fst` after loading.
    pub(crate) tokens: TokenMap,
    // Next token id to assign (strict upper bound over assigned ids).
    pub(crate) next_id: u32,
    // Sum of byte lengths of all tokens; used to pre-size builders on save.
    total_length: usize,
}
919
920impl TokenSet {
921 pub fn into_mut(self) -> Self {
922 let tokens = match self.tokens {
923 TokenMap::HashMap(map) => map,
924 TokenMap::Fst(map) => {
925 let mut new_map = HashMap::with_capacity(map.len());
926 let mut stream = map.into_stream();
927 while let Some((token, token_id)) = stream.next() {
928 new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
929 }
930
931 new_map
932 }
933 };
934
935 Self {
936 tokens: TokenMap::HashMap(tokens),
937 next_id: self.next_id,
938 total_length: self.total_length,
939 }
940 }
941
942 pub fn len(&self) -> usize {
943 self.tokens.len()
944 }
945
946 pub fn is_empty(&self) -> bool {
947 self.len() == 0
948 }
949
950 pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
951 match format {
952 TokenSetFormat::Arrow => self.into_arrow_batch(),
953 TokenSetFormat::Fst => self.into_fst_batch(),
954 }
955 }
956
957 fn into_arrow_batch(self) -> Result<RecordBatch> {
958 let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
959 let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
960
961 match self.tokens {
962 TokenMap::Fst(map) => {
963 let mut stream = map.stream();
964 while let Some((token, token_id)) = stream.next() {
965 token_builder.append_value(String::from_utf8_lossy(token));
966 token_id_builder.append_value(token_id as u32);
967 }
968 }
969 TokenMap::HashMap(map) => {
970 for (token, token_id) in map.into_iter().sorted_unstable() {
971 token_builder.append_value(token);
972 token_id_builder.append_value(token_id);
973 }
974 }
975 }
976
977 let token_col = token_builder.finish();
978 let token_id_col = token_id_builder.finish();
979
980 let schema = arrow_schema::Schema::new(vec![
981 arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
982 arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
983 ]);
984
985 let batch = RecordBatch::try_new(
986 Arc::new(schema),
987 vec![
988 Arc::new(token_col) as ArrayRef,
989 Arc::new(token_id_col) as ArrayRef,
990 ],
991 )?;
992 Ok(batch)
993 }
994
995 fn into_fst_batch(mut self) -> Result<RecordBatch> {
996 let fst_map = match std::mem::take(&mut self.tokens) {
997 TokenMap::Fst(map) => map,
998 TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
999 };
1000 let bytes = fst_map.into_fst().into_inner();
1001
1002 let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
1003 fst_builder.append_value(bytes);
1004 let fst_col = fst_builder.finish();
1005
1006 let mut next_id_builder = UInt32Builder::with_capacity(1);
1007 next_id_builder.append_value(self.next_id);
1008 let next_id_col = next_id_builder.finish();
1009
1010 let mut total_length_builder = UInt64Builder::with_capacity(1);
1011 total_length_builder.append_value(self.total_length as u64);
1012 let total_length_col = total_length_builder.finish();
1013
1014 let schema = arrow_schema::Schema::new(vec![
1015 arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
1016 arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
1017 arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
1018 ]);
1019
1020 let batch = RecordBatch::try_new(
1021 Arc::new(schema),
1022 vec![
1023 Arc::new(fst_col) as ArrayRef,
1024 Arc::new(next_id_col) as ArrayRef,
1025 Arc::new(total_length_col) as ArrayRef,
1026 ],
1027 )?;
1028 Ok(batch)
1029 }
1030
1031 fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
1032 let mut entries: Vec<_> = map.into_iter().collect();
1033 entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1034 let mut builder = fst::MapBuilder::memory();
1035 for (token, token_id) in entries {
1036 builder
1037 .insert(&token, token_id as u64)
1038 .map_err(|e| Error::Index {
1039 message: format!("failed to insert token {}: {}", token, e),
1040 location: location!(),
1041 })?;
1042 }
1043 Ok(builder.into_map())
1044 }
1045
1046 pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
1047 match format {
1048 TokenSetFormat::Arrow => Self::load_arrow(reader).await,
1049 TokenSetFormat::Fst => Self::load_fst(reader).await,
1050 }
1051 }
1052
1053 async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1054 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1055
1056 let (tokens, next_id, total_length) = spawn_blocking(move || {
1057 let mut next_id = 0;
1058 let mut total_length = 0;
1059 let mut tokens = fst::MapBuilder::memory();
1060
1061 let token_col = batch[TOKEN_COL].as_string::<i32>();
1062 let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1063
1064 for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1065 let token = token.ok_or(Error::Index {
1066 message: "found null token in token set".to_owned(),
1067 location: location!(),
1068 })?;
1069 next_id = next_id.max(token_id + 1);
1070 total_length += token.len();
1071 tokens
1072 .insert(token, token_id as u64)
1073 .map_err(|e| Error::Index {
1074 message: format!("failed to insert token {}: {}", token, e),
1075 location: location!(),
1076 })?;
1077 }
1078
1079 Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1080 })
1081 .await
1082 .map_err(|err| Error::Execution {
1083 message: format!("failed to spawn blocking task: {}", err),
1084 location: location!(),
1085 })??;
1086
1087 Ok(Self {
1088 tokens: TokenMap::Fst(tokens),
1089 next_id,
1090 total_length,
1091 })
1092 }
1093
1094 async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1095 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1096 if batch.num_rows() == 0 {
1097 return Err(Error::Index {
1098 message: "token set batch is empty".to_owned(),
1099 location: location!(),
1100 });
1101 }
1102
1103 let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1104 let bytes = fst_col.value(0);
1105 let map = fst::Map::new(bytes.to_vec()).map_err(|e| Error::Index {
1106 message: format!("failed to load fst tokens: {}", e),
1107 location: location!(),
1108 })?;
1109
1110 let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1111 let total_length_col =
1112 batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1113
1114 let next_id = next_id_col.values().first().copied().ok_or(Error::Index {
1115 message: "token next id column is empty".to_owned(),
1116 location: location!(),
1117 })?;
1118
1119 let total_length = total_length_col
1120 .values()
1121 .first()
1122 .copied()
1123 .ok_or(Error::Index {
1124 message: "token total length column is empty".to_owned(),
1125 location: location!(),
1126 })?;
1127
1128 Ok(Self {
1129 tokens: TokenMap::Fst(map),
1130 next_id,
1131 total_length: usize::try_from(total_length).map_err(|_| Error::Index {
1132 message: format!("token total length {} overflows usize", total_length),
1133 location: location!(),
1134 })?,
1135 })
1136 }
1137
1138 pub fn add(&mut self, token: String) -> u32 {
1139 let next_id = self.next_id();
1140 let len = token.len();
1141 let token_id = match self.tokens {
1142 TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1143 _ => unreachable!("tokens must be HashMap while indexing"),
1144 };
1145
1146 if token_id == next_id {
1148 self.next_id += 1;
1149 self.total_length += len;
1150 }
1151
1152 token_id
1153 }
1154
1155 pub(crate) fn get_or_add(&mut self, token: &str) -> u32 {
1156 let next_id = self.next_id;
1157 match self.tokens {
1158 TokenMap::HashMap(ref mut map) => {
1159 if let Some(&token_id) = map.get(token) {
1160 return token_id;
1161 }
1162
1163 map.insert(token.to_owned(), next_id);
1164 }
1165 _ => unreachable!("tokens must be HashMap while indexing"),
1166 }
1167
1168 self.next_id += 1;
1169 self.total_length += token.len();
1170 next_id
1171 }
1172
1173 pub fn get(&self, token: &str) -> Option<u32> {
1174 match self.tokens {
1175 TokenMap::HashMap(ref map) => map.get(token).copied(),
1176 TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
1177 }
1178 }
1179
    /// Drop the tokens whose ids appear in `removed_token_ids` (must be sorted
    /// ascending) and compact the surviving ids so they stay dense.
    ///
    /// An FST representation is first thawed into a `HashMap` (the FST is
    /// immutable); the set is left in `HashMap` form afterwards.
    /// NOTE(review): `next_id` and `total_length` are not adjusted here —
    /// presumably callers rebuild/re-persist after remapping; confirm.
    pub fn remap(&mut self, removed_token_ids: &[u32]) {
        if removed_token_ids.is_empty() {
            return;
        }

        let mut map = match std::mem::take(&mut self.tokens) {
            TokenMap::HashMap(map) => map,
            TokenMap::Fst(map) => {
                let mut new_map = HashMap::with_capacity(map.len());
                let mut stream = map.into_stream();
                while let Some((token, token_id)) = stream.next() {
                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
                }

                new_map
            }
        };

        map.retain(
            // binary_search answers two questions at once: Ok means the token
            // was removed; Err(index) is the number of removed ids preceding
            // this one — exactly how far its id must shift down.
            |_, token_id| match removed_token_ids.binary_search(token_id) {
                Ok(_) => false,
                Err(index) => {
                    *token_id -= index as u32;
                    true
                }
            },
        );

        self.tokens = TokenMap::HashMap(map);
    }
1211
    /// The id that will be assigned to the next new token.
    pub fn next_id(&self) -> u32 {
        self.next_id
    }
1215}
1216
/// Reads posting lists for an inverted index, supporting both the legacy
/// row-per-posting-entry layout and the current compressed row-per-token
/// layout.
pub struct PostingListReader {
    reader: Arc<dyn IndexReader>,

    // Legacy layout only: start row of each token's posting list, parsed from
    // the schema metadata.
    offsets: Option<Vec<usize>>,

    // Per-token score upper bounds (may be absent for old legacy indexes).
    max_scores: Option<Vec<f32>>,

    // Current layout only: number of posting entries per token.
    lengths: Option<Vec<u32>>,

    // Whether the index stores token positions (needed for phrase queries).
    has_position: bool,

    index_cache: WeakLanceCache,
}
1234
1235impl std::fmt::Debug for PostingListReader {
1236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1237 f.debug_struct("InvertedListReader")
1238 .field("offsets", &self.offsets)
1239 .field("max_scores", &self.max_scores)
1240 .finish()
1241 }
1242}
1243
1244impl DeepSizeOf for PostingListReader {
1245 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
1246 self.offsets.deep_size_of_children(context)
1247 + self.max_scores.deep_size_of_children(context)
1248 + self.lengths.deep_size_of_children(context)
1249 }
1250}
1251
1252impl PostingListReader {
1253 pub(crate) async fn try_new(
1254 reader: Arc<dyn IndexReader>,
1255 index_cache: &LanceCache,
1256 ) -> Result<Self> {
1257 let has_position = reader.schema().field(POSITION_COL).is_some();
1258 let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
1259 let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
1260 (Some(offsets), max_scores, None)
1261 } else {
1262 let metadata = reader
1263 .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
1264 .await?;
1265 let max_scores = metadata[MAX_SCORE_COL]
1266 .as_primitive::<Float32Type>()
1267 .values()
1268 .to_vec();
1269 let lengths = metadata[LENGTH_COL]
1270 .as_primitive::<UInt32Type>()
1271 .values()
1272 .to_vec();
1273 (None, Some(max_scores), Some(lengths))
1274 };
1275
1276 Ok(Self {
1277 reader,
1278 offsets,
1279 max_scores,
1280 lengths,
1281 has_position,
1282 index_cache: WeakLanceCache::from(index_cache),
1283 })
1284 }
1285
1286 fn load_metadata(
1289 schema: &lance_core::datatypes::Schema,
1290 ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
1291 let offsets = schema.metadata.get("offsets").ok_or(Error::Index {
1292 message: "offsets not found in metadata".to_owned(),
1293 location: location!(),
1294 })?;
1295 let offsets = serde_json::from_str(offsets)?;
1296
1297 let max_scores = schema
1298 .metadata
1299 .get("max_scores")
1300 .map(|max_scores| serde_json::from_str(max_scores))
1301 .transpose()?;
1302 Ok((offsets, max_scores))
1303 }
1304
1305 pub fn len(&self) -> usize {
1307 match self.offsets {
1308 Some(ref offsets) => offsets.len(),
1309 None => self.reader.num_rows(),
1310 }
1311 }
1312
1313 pub fn is_empty(&self) -> bool {
1314 self.len() == 0
1315 }
1316
1317 pub(crate) fn has_positions(&self) -> bool {
1318 self.has_position
1319 }
1320
1321 pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1322 let token_id = token_id as usize;
1323
1324 match self.offsets {
1325 Some(ref offsets) => {
1326 let next_offset = offsets
1327 .get(token_id + 1)
1328 .copied()
1329 .unwrap_or(self.reader.num_rows());
1330 next_offset - offsets[token_id]
1331 }
1332 None => {
1333 if let Some(lengths) = &self.lengths {
1334 lengths[token_id] as usize
1335 } else {
1336 panic!("posting list reader is not initialized")
1337 }
1338 }
1339 }
1340 }
1341
1342 pub(crate) async fn posting_batch(
1343 &self,
1344 token_id: u32,
1345 with_position: bool,
1346 ) -> Result<RecordBatch> {
1347 if self.offsets.is_some() {
1348 self.posting_batch_legacy(token_id, with_position).await
1349 } else {
1350 let token_id = token_id as usize;
1351 let columns = if with_position {
1352 vec![POSTING_COL, POSITION_COL]
1353 } else {
1354 vec![POSTING_COL]
1355 };
1356 let batch = self
1357 .reader
1358 .read_range(token_id..token_id + 1, Some(&columns))
1359 .await?;
1360 Ok(batch)
1361 }
1362 }
1363
1364 async fn posting_batch_legacy(
1365 &self,
1366 token_id: u32,
1367 with_position: bool,
1368 ) -> Result<RecordBatch> {
1369 let mut columns = vec![ROW_ID, FREQUENCY_COL];
1370 if with_position {
1371 columns.push(POSITION_COL);
1372 }
1373
1374 let length = self.posting_len(token_id);
1375 let token_id = token_id as usize;
1376 let offset = self.offsets.as_ref().unwrap()[token_id];
1377 let batch = self
1378 .reader
1379 .read_range(offset..offset + length, Some(&columns))
1380 .await?;
1381 Ok(batch)
1382 }
1383
1384 #[instrument(level = "debug", skip(self, metrics))]
1385 pub(crate) async fn posting_list(
1386 &self,
1387 token_id: u32,
1388 is_phrase_query: bool,
1389 metrics: &dyn MetricsCollector,
1390 ) -> Result<PostingList> {
1391 let cache_key = PostingListKey { token_id };
1392 let mut posting = self
1393 .index_cache
1394 .get_or_insert_with_key(cache_key, || async move {
1395 metrics.record_part_load();
1396 info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
1397 let batch = self.posting_batch(token_id, false).await?;
1398 self.posting_list_from_batch(&batch, token_id)
1399 })
1400 .await?
1401 .as_ref()
1402 .clone();
1403
1404 if is_phrase_query {
1405 let positions = self.read_positions(token_id).await?;
1407 posting.set_positions(positions);
1408 }
1409
1410 Ok(posting)
1411 }
1412
1413 pub(crate) fn posting_list_from_batch(
1414 &self,
1415 batch: &RecordBatch,
1416 token_id: u32,
1417 ) -> Result<PostingList> {
1418 let posting_list = PostingList::from_batch(
1419 batch,
1420 self.max_scores
1421 .as_ref()
1422 .map(|max_scores| max_scores[token_id as usize]),
1423 self.lengths
1424 .as_ref()
1425 .map(|lengths| lengths[token_id as usize]),
1426 )?;
1427 Ok(posting_list)
1428 }
1429
1430 async fn prewarm(&self) -> Result<()> {
1431 let batch = self.read_batch(false).await?;
1432 for token_id in 0..self.len() {
1433 let posting_range = self.posting_list_range(token_id as u32);
1434 let batch = batch.slice(posting_range.start, posting_range.end - posting_range.start);
1435 let batch = batch.shrink_to_fit()?;
1438 let posting_list = self.posting_list_from_batch(&batch, token_id as u32)?;
1439 let inserted = self
1440 .index_cache
1441 .insert_with_key(
1442 &PostingListKey {
1443 token_id: token_id as u32,
1444 },
1445 Arc::new(posting_list),
1446 )
1447 .await;
1448
1449 if !inserted {
1450 return Err(Error::Internal {
1451 message: "Failed to prewarm index: cache is no longer available".to_string(),
1452 location: location!(),
1453 });
1454 }
1455 }
1456
1457 Ok(())
1458 }
1459
1460 pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1461 let columns = self.posting_columns(with_position);
1462 let batch = self
1463 .reader
1464 .read_range(0..self.reader.num_rows(), Some(&columns))
1465 .await?;
1466 Ok(batch)
1467 }
1468
1469 pub(crate) async fn read_all(
1470 &self,
1471 with_position: bool,
1472 ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1473 let batch = self.read_batch(with_position).await?;
1474 Ok((0..self.len()).map(move |i| {
1475 let token_id = i as u32;
1476 let range = self.posting_list_range(token_id);
1477 let batch = batch.slice(i, range.end - range.start);
1478 self.posting_list_from_batch(&batch, token_id)
1479 }))
1480 }
1481
1482 async fn read_positions(&self, token_id: u32) -> Result<ListArray> {
1483 let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
1484 let batch = self
1485 .reader
1486 .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
1487 .await.map_err(|e| {
1488 match e {
1489 Error::Schema { .. } => Error::invalid_input(
1490 "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
1491 location!(),
1492 ),
1493 e => e
1494 }
1495 })?;
1496 Result::Ok(Positions(batch[POSITION_COL]
1497 .as_list::<i32>()
1498 .clone()))
1499 }).await?;
1500 Ok(positions.0.clone())
1501 }
1502
1503 fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1504 match self.offsets {
1505 Some(ref offsets) => {
1506 let offset = offsets[token_id as usize];
1507 let posting_len = self.posting_len(token_id);
1508 offset..offset + posting_len
1509 }
1510 None => {
1511 let token_id = token_id as usize;
1512 token_id..token_id + 1
1513 }
1514 }
1515 }
1516
1517 fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1518 let mut base_columns = match self.offsets {
1519 Some(_) => vec![ROW_ID, FREQUENCY_COL],
1520 None => vec![POSTING_COL],
1521 };
1522 if with_position {
1523 base_columns.push(POSITION_COL);
1524 }
1525 base_columns
1526 }
1527}
1528
// Newtype wrapper over the per-document position lists of one token, so the
// cached value can implement `DeepSizeOf`.
#[derive(Clone)]
pub struct Positions(ListArray);
1533
1534impl DeepSizeOf for Positions {
1535 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1536 self.0.get_buffer_memory_size()
1537 }
1538}
1539
/// Index-cache key for a token's posting list.
#[derive(Debug, Clone)]
pub struct PostingListKey {
    pub token_id: u32,
}
1545
1546impl CacheKey for PostingListKey {
1547 type ValueType = PostingList;
1548
1549 fn key(&self) -> std::borrow::Cow<'_, str> {
1550 format!("postings-{}", self.token_id).into()
1551 }
1552}
1553
/// Index-cache key for a token's position lists (phrase queries only).
#[derive(Debug, Clone)]
pub struct PositionKey {
    pub token_id: u32,
}
1558
1559impl CacheKey for PositionKey {
1560 type ValueType = Positions;
1561
1562 fn key(&self) -> std::borrow::Cow<'_, str> {
1563 format!("positions-{}", self.token_id).into()
1564 }
1565}
1566
/// One token's posting list, in whichever layout the index was written.
#[derive(Debug, Clone, DeepSizeOf)]
pub enum PostingList {
    // Legacy layout: plain row-id / frequency arrays.
    Plain(PlainPostingList),
    // Current layout: block-compressed postings with per-block max scores.
    Compressed(CompressedPostingList),
}
1572
impl PostingList {
    /// Decode a posting list from a batch, choosing the layout by the
    /// presence of `POSTING_COL`. `max_score` and `length` are required for
    /// the compressed layout (they are stored outside the blocks).
    pub fn from_batch(
        batch: &RecordBatch,
        max_score: Option<f32>,
        length: Option<u32>,
    ) -> Result<Self> {
        match batch.column_by_name(POSTING_COL) {
            Some(_) => {
                debug_assert!(max_score.is_some() && length.is_some());
                let posting =
                    CompressedPostingList::from_batch(batch, max_score.unwrap(), length.unwrap());
                Ok(Self::Compressed(posting))
            }
            None => {
                let posting = PlainPostingList::from_batch(batch, max_score);
                Ok(Self::Plain(posting))
            }
        }
    }

    pub fn iter(&self) -> PostingListIterator<'_> {
        PostingListIterator::new(self)
    }

    pub fn has_position(&self) -> bool {
        match self {
            Self::Plain(posting) => posting.positions.is_some(),
            Self::Compressed(posting) => posting.positions.is_some(),
        }
    }

    /// Attach lazily-loaded position lists (used for phrase queries).
    pub fn set_positions(&mut self, positions: ListArray) {
        match self {
            Self::Plain(posting) => posting.positions = Some(positions),
            Self::Compressed(posting) => {
                // Compressed batches hold a single row: unwrap the outer list.
                posting.positions = Some(positions.value(0).as_list::<i32>().clone());
            }
        }
    }

    /// Score upper bound for this token, when known.
    pub fn max_score(&self) -> Option<f32> {
        match self {
            Self::Plain(posting) => posting.max_score,
            Self::Compressed(posting) => Some(posting.max_score),
        }
    }

    /// Number of documents containing the token.
    pub fn len(&self) -> usize {
        match self {
            Self::Plain(posting) => posting.len(),
            Self::Compressed(posting) => posting.length as usize,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Convert back into a mutable builder, translating legacy row ids into
    /// doc ids via `docs` and re-sorting into doc-id order.
    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
        let mut builder = PostingListBuilder::new(self.has_position());
        match self {
            Self::Plain(posting) => {
                // Legacy lists are keyed (and ordered) by row id, so entries
                // must be remapped to doc ids and sorted before insertion.
                struct Item {
                    doc_id: u32,
                    positions: PositionRecorder,
                }
                let doc_ids = docs
                    .row_ids
                    .iter()
                    .enumerate()
                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
                    .collect::<HashMap<_, _>>();
                let mut items = Vec::with_capacity(posting.len());
                for (row_id, freq, positions) in posting.iter() {
                    let freq = freq as u32;
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
                        }
                        // No stored positions: keep only the count.
                        None => PositionRecorder::Count(freq),
                    };
                    items.push(Item {
                        doc_id: doc_ids[&row_id],
                        positions,
                    });
                }
                items.sort_unstable_by_key(|item| item.doc_id);
                for item in items {
                    builder.add(item.doc_id, item.positions);
                }
            }
            Self::Compressed(posting) => {
                // Compressed lists are already in doc-id order.
                posting.iter().for_each(|(doc_id, freq, positions)| {
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>().into())
                        }
                        None => PositionRecorder::Count(freq),
                    };
                    builder.add(doc_id, positions);
                });
            }
        }
        builder
    }
}
1683
1684#[derive(Debug, PartialEq, Clone)]
1685pub struct PlainPostingList {
1686 pub row_ids: ScalarBuffer<u64>,
1687 pub frequencies: ScalarBuffer<f32>,
1688 pub max_score: Option<f32>,
1689 pub positions: Option<ListArray>, }
1691
1692impl DeepSizeOf for PlainPostingList {
1693 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1694 self.row_ids.len() * std::mem::size_of::<u64>()
1695 + self.frequencies.len() * std::mem::size_of::<u32>()
1696 + self
1697 .positions
1698 .as_ref()
1699 .map(|positions| positions.get_buffer_memory_size())
1700 .unwrap_or(0)
1701 }
1702}
1703
impl PlainPostingList {
    pub fn new(
        row_ids: ScalarBuffer<u64>,
        frequencies: ScalarBuffer<f32>,
        max_score: Option<f32>,
        positions: Option<ListArray>,
    ) -> Self {
        Self {
            row_ids,
            frequencies,
            max_score,
            positions,
        }
    }

    /// Decode from a legacy posting batch (`ROW_ID` + `FREQUENCY_COL`, with an
    /// optional `POSITION_COL`).
    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
        let frequencies = batch[FREQUENCY_COL]
            .as_primitive::<Float32Type>()
            .values()
            .clone();
        let positions = batch
            .column_by_name(POSITION_COL)
            .map(|col| col.as_list::<i32>().clone());

        Self::new(row_ids, frequencies, max_score, positions)
    }

    pub fn len(&self) -> usize {
        self.row_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterate `(row_id, frequency, positions)`; each entry's positions are
    /// sliced lazily out of the list array via its offsets, avoiding a copy
    /// of every sublist up front.
    pub fn iter(&self) -> PlainPostingListIterator<'_> {
        Box::new(
            self.row_ids
                .iter()
                .zip(self.frequencies.iter())
                .enumerate()
                .map(|(idx, (doc_id, freq))| {
                    (
                        *doc_id,
                        *freq,
                        self.positions.as_ref().map(|p| {
                            // offsets[idx]..offsets[idx + 1] delimits this
                            // document's positions in the flat values buffer.
                            let start = p.value_offsets()[idx] as usize;
                            let end = p.value_offsets()[idx + 1] as usize;
                            Box::new(
                                p.values().as_primitive::<Int32Type>().values()[start..end]
                                    .iter()
                                    .map(|pos| *pos as u32),
                            ) as _
                        }),
                    )
                }),
        )
    }

    /// The `i`-th posting entry as a row-id-keyed doc info.
    #[inline]
    pub fn doc(&self, i: usize) -> LocatedDocInfo {
        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
    }

    /// Position array of the `index`-th entry, if positions are stored.
    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
        self.positions
            .as_ref()
            .map(|positions| positions.value(index))
    }

    pub fn max_score(&self) -> Option<f32> {
        self.max_score
    }

    pub fn row_id(&self, i: usize) -> u64 {
        self.row_ids[i]
    }
}
1783
/// Block-compressed posting list; each binary value in `blocks` is one
/// compressed block, prefixed with its max score and least doc id (see
/// `block_max_score` / `block_least_doc_id`) to support block skipping.
#[derive(Debug, PartialEq, Clone)]
pub struct CompressedPostingList {
    // Score upper bound across the whole list.
    pub max_score: f32,
    // Total number of posting entries across all blocks.
    pub length: u32,
    pub blocks: LargeBinaryArray,
    // Compressed per-document positions, when the index stores them.
    pub positions: Option<ListArray>,
}
1793
1794impl DeepSizeOf for CompressedPostingList {
1795 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1796 self.blocks.get_buffer_memory_size()
1797 + self
1798 .positions
1799 .as_ref()
1800 .map(|positions| positions.get_buffer_memory_size())
1801 .unwrap_or(0)
1802 }
1803}
1804
impl CompressedPostingList {
    pub fn new(
        blocks: LargeBinaryArray,
        max_score: f32,
        length: u32,
        positions: Option<ListArray>,
    ) -> Self {
        Self {
            max_score,
            length,
            blocks,
            positions,
        }
    }

    /// Decode from a single-row batch: `POSTING_COL` holds the list of
    /// compressed blocks; `POSITION_COL` (if present) holds a doubly-nested
    /// list of compressed per-document positions.
    pub fn from_batch(batch: &RecordBatch, max_score: f32, length: u32) -> Self {
        debug_assert_eq!(batch.num_rows(), 1);
        let blocks = batch[POSTING_COL]
            .as_list::<i32>()
            .value(0)
            .as_binary::<i64>()
            .clone();
        let positions = batch
            .column_by_name(POSITION_COL)
            .map(|col| col.as_list::<i32>().value(0).as_list::<i32>().clone());

        Self {
            max_score,
            length,
            blocks,
            positions,
        }
    }

    pub fn iter(&self) -> CompressedPostingListIterator {
        CompressedPostingListIterator::new(
            self.length as usize,
            self.blocks.clone(),
            self.positions.clone(),
        )
    }

    /// Max score of one block, read from its header (bytes 0..4, LE f32).
    pub fn block_max_score(&self, block_idx: usize) -> f32 {
        let block = self.blocks.value(block_idx);
        block[0..4].try_into().map(f32::from_le_bytes).unwrap()
    }

    /// Smallest doc id in one block, from its header (bytes 4..8, LE u32).
    pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
        let block = self.blocks.value(block_idx);
        block[4..8].try_into().map(u32::from_le_bytes).unwrap()
    }
}
1857
/// Accumulates one token's posting list during indexing: doc ids with their
/// term frequencies and, optionally, per-document token positions.
#[derive(Debug)]
pub struct PostingListBuilder {
    pub doc_ids: ExpLinkedList<u32>,
    pub frequencies: ExpLinkedList<u32>,
    // Present iff the index records token positions.
    pub positions: Option<PositionBuilder>,
}
1864
impl PostingListBuilder {
    /// Approximate in-memory footprint in bytes.
    pub fn size(&self) -> u64 {
        (std::mem::size_of::<u32>() * self.doc_ids.len()
            + std::mem::size_of::<u32>() * self.frequencies.len()
            + self
                .positions
                .as_ref()
                .map(|positions| positions.size())
                .unwrap_or(0)) as u64
    }

    pub fn has_positions(&self) -> bool {
        self.positions.is_some()
    }

    pub fn new(with_position: bool) -> Self {
        Self {
            doc_ids: ExpLinkedList::new().with_capacity_limit(128),
            frequencies: ExpLinkedList::new().with_capacity_limit(128),
            positions: with_position.then(PositionBuilder::new),
        }
    }

    pub fn len(&self) -> usize {
        self.doc_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterate `(doc_id, frequency, positions)` in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&u32, &u32, Option<&[u32]>)> {
        self.doc_ids
            .iter()
            .zip(self.frequencies.iter())
            .enumerate()
            .map(|(idx, (doc_id, freq))| {
                let positions = self.positions.as_ref().map(|positions| positions.get(idx));
                (doc_id, freq, positions)
            })
    }

    /// Append one document; the frequency is the recorder's length, and the
    /// concrete positions are kept only when this builder tracks them.
    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
        self.doc_ids.push(doc_id);
        self.frequencies.push(term_positions.len());
        if let Some(positions) = self.positions.as_mut() {
            positions.push(term_positions.into_vec());
        }
    }

    /// Consume the builder and serialize it into a single-row batch in the
    /// compressed layout: posting blocks, max score, length, and (optionally)
    /// per-document compressed positions.
    pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
        let length = self.len();
        // NOTE(review): with no blocks this folds to f32::MIN — presumably
        // empty posting lists are never serialized; confirm at call sites.
        let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);

        let schema = inverted_list_schema(self.has_positions());
        let compressed = compress_posting_list(
            self.doc_ids.len(),
            self.doc_ids.iter(),
            self.frequencies.iter(),
            block_max_scores.into_iter(),
        )?;
        // Single row: the list column spans the entire compressed array.
        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
        let mut columns = vec![
            Arc::new(ListArray::try_new(
                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
                offsets,
                Arc::new(compressed),
                None,
            )?) as ArrayRef,
            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
            Arc::new(UInt32Array::from_iter_values(std::iter::once(
                self.len() as u32
            ))) as ArrayRef,
        ];

        if let Some(positions) = self.positions.as_ref() {
            // list<list<binary>>: one compressed position blob per document,
            // all under the single outer row.
            let mut position_builder = ListBuilder::new(ListBuilder::with_capacity(
                LargeBinaryBuilder::new(),
                length,
            ));
            for index in 0..length {
                let positions_in_doc = positions.get(index);
                let compressed = compress_positions(positions_in_doc)?;
                let inner_builder = position_builder.values();
                inner_builder.append_value(compressed.into_iter());
            }
            position_builder.append(true);
            let position_col = position_builder.finish();
            columns.push(Arc::new(position_col));
        }

        let batch = RecordBatch::try_new(schema, columns)?;
        Ok(batch)
    }

    /// Drop documents listed in `removed` (sorted ascending doc ids) and shift
    /// the surviving doc ids down so the doc-id space stays dense.
    pub fn remap(&mut self, removed: &[u32]) {
        // `cursor` counts removed ids < the current doc id — exactly the
        // amount each surviving id must shift down by.
        let mut cursor = 0;
        let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
        let mut new_frequencies = ExpLinkedList::with_capacity(self.len());
        let mut new_positions = self.positions.as_mut().map(|_| PositionBuilder::new());
        for (&doc_id, &freq, positions) in self.iter() {
            while cursor < removed.len() && removed[cursor] < doc_id {
                cursor += 1;
            }
            if cursor < removed.len() && removed[cursor] == doc_id {
                // This document itself was removed; drop the entry.
                continue;
            }
            new_doc_ids.push(doc_id - cursor as u32);
            new_frequencies.push(freq);
            if let Some(new_positions) = new_positions.as_mut() {
                new_positions.push(positions.unwrap().to_vec());
            }
        }

        self.doc_ids = new_doc_ids;
        self.frequencies = new_frequencies;
        self.positions = new_positions;
    }
}
1988
/// Flattened list-of-lists of token positions: `offsets[i]..offsets[i + 1]`
/// delimits document `i`'s slice of `positions`.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct PositionBuilder {
    positions: Vec<u32>,
    offsets: Vec<i32>,
}
1994
impl Default for PositionBuilder {
    // Delegates to `new` so the leading offset sentinel is in place.
    fn default() -> Self {
        Self::new()
    }
}
2000
2001impl PositionBuilder {
2002 pub fn new() -> Self {
2003 Self {
2004 positions: Vec::new(),
2005 offsets: vec![0],
2006 }
2007 }
2008
2009 pub fn size(&self) -> usize {
2010 std::mem::size_of::<u32>() * self.positions.len()
2011 + std::mem::size_of::<i32>() * self.offsets.len()
2012 }
2013
2014 pub fn total_len(&self) -> usize {
2015 self.positions.len()
2016 }
2017
2018 pub fn push(&mut self, positions: Vec<u32>) {
2019 self.positions.extend(positions);
2020 self.offsets.push(self.positions.len() as i32);
2021 }
2022
2023 pub fn get(&self, i: usize) -> &[u32] {
2024 let start = self.offsets[i] as usize;
2025 let end = self.offsets[i + 1] as usize;
2026 &self.positions[start..end]
2027 }
2028}
2029
2030impl From<Vec<Vec<u32>>> for PositionBuilder {
2031 fn from(positions: Vec<Vec<u32>>) -> Self {
2032 let mut builder = Self::new();
2033 builder.offsets.reserve(positions.len());
2034 for pos in positions {
2035 builder.push(pos);
2036 }
2037 builder
2038 }
2039}
2040
/// A posting entry in either id space.
#[derive(Debug, Clone, DeepSizeOf, Copy)]
pub enum DocInfo {
    // Legacy entry keyed by global row id.
    Located(LocatedDocInfo),
    // Current entry keyed by index-local doc id.
    Raw(RawDocInfo),
}
2046
2047impl DocInfo {
2048 pub fn doc_id(&self) -> u64 {
2049 match self {
2050 Self::Raw(info) => info.doc_id as u64,
2051 Self::Located(info) => info.row_id,
2052 }
2053 }
2054
2055 pub fn frequency(&self) -> u32 {
2056 match self {
2057 Self::Raw(info) => info.frequency,
2058 Self::Located(info) => info.frequency as u32,
2059 }
2060 }
2061}
2062
// Equality and ordering consider only the doc id; the frequency is
// deliberately ignored so entries can be merged and ordered by document.
impl Eq for DocInfo {}

impl PartialEq for DocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.doc_id() == other.doc_id()
    }
}

impl PartialOrd for DocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.doc_id().cmp(&other.doc_id())
    }
}
2082
/// Posting entry in the legacy layout: addressed by global row id, with an
/// f32 term frequency.
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct LocatedDocInfo {
    pub row_id: u64,
    pub frequency: f32,
}
2088
impl LocatedDocInfo {
    /// Create an entry for `row_id` with the given term frequency.
    pub fn new(row_id: u64, frequency: f32) -> Self {
        Self { row_id, frequency }
    }
}
2094
// Equality and ordering consider only the row id; excluding the f32 frequency
// also keeps the manual `Eq` impl sound.
impl Eq for LocatedDocInfo {}

impl PartialEq for LocatedDocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.row_id == other.row_id
    }
}

impl PartialOrd for LocatedDocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for LocatedDocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.row_id.cmp(&other.row_id)
    }
}
2114
/// Posting entry in the current layout: addressed by index-local doc id,
/// with an integral term frequency.
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct RawDocInfo {
    pub doc_id: u32,
    pub frequency: u32,
}
2120
impl RawDocInfo {
    /// Create an entry for `doc_id` with the given term frequency.
    pub fn new(doc_id: u32, frequency: u32) -> Self {
        Self { doc_id, frequency }
    }
}
2126
// Equality and ordering consider only the doc id; frequency is ignored.
impl Eq for RawDocInfo {}

impl PartialEq for RawDocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.doc_id == other.doc_id
    }
}

impl PartialOrd for RawDocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for RawDocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.doc_id.cmp(&other.doc_id)
    }
}
2146
/// Maps between row ids and dense doc ids and tracks per-document token
/// counts (used for BM25 length normalization).
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct DocSet {
    // Indexed by doc id.
    row_ids: Vec<u64>,
    num_tokens: Vec<u32>,
    // (row_id, doc_id) pairs sorted by row id for reverse lookup; left empty
    // for legacy indexes, where row ids are sorted and double as doc ids.
    inv: Vec<(u64, u32)>,

    total_tokens: u64,
}
2158
impl DocSet {
    /// Number of documents.
    #[inline]
    pub fn len(&self) -> usize {
        self.row_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterate `(row_id, num_tokens)` pairs in doc-id order.
    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
        self.row_ids.iter().zip(self.num_tokens.iter())
    }

    pub fn row_id(&self, doc_id: u32) -> u64 {
        self.row_ids[doc_id as usize]
    }

    /// Reverse lookup: map a row id back to its doc id, or `None` if the row
    /// is not in the set.
    pub fn doc_id(&self, row_id: u64) -> Option<u64> {
        if self.inv.is_empty() {
            // Legacy set: `row_ids` is sorted and legacy posting lists are
            // keyed by row id itself, so only membership is checked and the
            // row id doubles as the doc id.
            match self.row_ids.binary_search(&row_id) {
                Ok(_) => Some(row_id),
                Err(_) => None,
            }
        } else {
            match self.inv.binary_search_by_key(&row_id, |x| x.0) {
                Ok(idx) => Some(self.inv[idx].1 as u64),
                Err(_) => None,
            }
        }
    }
    pub fn total_tokens_num(&self) -> u64 {
        self.total_tokens
    }

    /// Average document length in tokens (BM25's `avgdl`).
    #[inline]
    pub fn average_length(&self) -> f32 {
        self.total_tokens as f32 / self.len() as f32
    }

    /// Compute the BM25 block-max score for each `BLOCK_SIZE` chunk of a
    /// posting list, given its doc ids and term frequencies.
    pub fn calculate_block_max_scores<'a>(
        &self,
        doc_ids: impl Iterator<Item = &'a u32>,
        freqs: impl Iterator<Item = &'a u32>,
    ) -> Vec<f32> {
        let avgdl = self.average_length();
        // NOTE(review): the posting-list length is taken from the iterator's
        // size hint — callers must pass exact-sized iterators; confirm.
        let length = doc_ids.size_hint().0;
        let num_blocks = length.div_ceil(BLOCK_SIZE);
        let mut block_max_scores = Vec::with_capacity(num_blocks);
        // idf and (K1 + 1) are constant for the token; hoisted out of the loop.
        let idf_scale = idf(length, self.len()) * (K1 + 1.0);
        let mut max_score = f32::MIN;
        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
            let freq = *freq as f32;
            let score = freq / (freq + doc_norm);
            if score > max_score {
                max_score = score;
            }
            if (i + 1) % BLOCK_SIZE == 0 {
                // Block boundary: finalize this block's max and reset.
                max_score *= idf_scale;
                block_max_scores.push(max_score);
                max_score = f32::MIN;
            }
        }
        // Trailing partial block, if any.
        if length % BLOCK_SIZE > 0 {
            max_score *= idf_scale;
            block_max_scores.push(max_score);
        }
        block_max_scores
    }

    /// Serialize to a `(ROW_ID, NUM_TOKEN_COL)` batch for persistence.
    pub fn to_batch(&self) -> Result<RecordBatch> {
        let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
        let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());

        let schema = arrow_schema::Schema::new(vec![
            arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
            arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                Arc::new(row_id_col) as ArrayRef,
                Arc::new(num_tokens_col) as ArrayRef,
            ],
        )?;
        Ok(batch)
    }

    /// Load a doc set, optionally remapping row ids through a frag-reuse
    /// index (rows that no longer exist are dropped).
    ///
    /// Legacy sets are re-sorted by row id and use no reverse map (`inv`);
    /// current sets keep on-disk order and build `inv` for reverse lookup.
    pub async fn load(
        reader: Arc<dyn IndexReader>,
        is_legacy: bool,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
    ) -> Result<Self> {
        let batch = reader.read_range(0..reader.num_rows(), None).await?;
        let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
        let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();

        if is_legacy {
            // Legacy: sort by row id so binary search works; doc ids are the
            // sorted positions and `inv` stays empty.
            let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
                .values()
                .iter()
                .filter_map(|id| {
                    if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
                        frag_reuse_index_ref.remap_row_id(*id)
                    } else {
                        Some(*id)
                    }
                })
                .zip(num_tokens_col.values().iter())
                .sorted_unstable_by_key(|x| x.0)
                .unzip();

            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
            return Ok(Self {
                row_ids,
                num_tokens,
                inv: Vec::new(),
                total_tokens,
            });
        }

        if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
            // Remap row ids, dropping rows the frag-reuse index deleted.
            let mut row_ids = Vec::with_capacity(row_id_col.len());
            let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
            for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
                if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
                    row_ids.push(new_row_id);
                    num_tokens.push(*num_token);
                }
            }

            let mut inv: Vec<(u64, u32)> = row_ids
                .iter()
                .enumerate()
                .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
                .collect();
            inv.sort_unstable_by_key(|entry| entry.0);

            let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
            return Ok(Self {
                row_ids,
                num_tokens,
                inv,
                total_tokens,
            });
        }

        let row_ids = row_id_col.values().to_vec();
        let num_tokens = num_tokens_col.values().to_vec();
        let mut inv: Vec<(u64, u32)> = row_ids
            .iter()
            .enumerate()
            .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
            .collect();
        // Already sorted by row id when row ids are sorted; skip the sort.
        if !row_ids.is_sorted() {
            inv.sort_unstable_by_key(|entry| entry.0);
        }
        let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
        Ok(Self {
            row_ids,
            num_tokens,
            inv,
            total_tokens,
        })
    }

    /// Apply a row-id mapping (`Some(new)` = moved, `None` = deleted),
    /// compacting the doc-id space; returns removed doc ids (ascending).
    ///
    /// NOTE(review): `total_tokens` and `inv` are not refreshed here — looks
    /// like callers rebuild/re-persist after a remap; confirm.
    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
        let mut removed = Vec::new();
        let len = self.len();
        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
            match mapping.get(&row_id) {
                Some(Some(new_row_id)) => {
                    self.row_ids.push(*new_row_id);
                    self.num_tokens.push(num_token);
                }
                Some(None) => {
                    removed.push(doc_id as u32);
                }
                // Rows not mentioned by the mapping are kept as-is.
                None => {
                    self.row_ids.push(row_id);
                    self.num_tokens.push(num_token);
                }
            }
        }
        removed
    }

    /// Token count of a document, by doc id.
    #[inline]
    pub fn num_tokens(&self, doc_id: u32) -> u32 {
        self.num_tokens[doc_id as usize]
    }

    /// Token count looked up by row id; returns 0 when the row is absent.
    /// NOTE(review): the binary search assumes `row_ids` is sorted, which
    /// holds for legacy sets — confirm callers only use this in that path.
    #[inline]
    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
        self.row_ids
            .binary_search(&row_id)
            .map(|idx| self.num_tokens[idx])
            .unwrap_or(0)
    }

    /// Register a new document and return its doc id.
    pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
        self.row_ids.push(row_id);
        self.num_tokens.push(num_tokens);
        self.total_tokens += num_tokens as u64;
        self.row_ids.len() as u32 - 1
    }
}
2380
2381pub fn flat_full_text_search(
2382 batches: &[&RecordBatch],
2383 doc_col: &str,
2384 query: &str,
2385 tokenizer: Option<Box<dyn LanceTokenizer>>,
2386) -> Result<Vec<u64>> {
2387 if batches.is_empty() {
2388 return Ok(vec![]);
2389 }
2390
2391 if is_phrase_query(query) {
2392 return Err(Error::invalid_input(
2393 "phrase query is not supported for flat full text search, try using FTS index",
2394 location!(),
2395 ));
2396 }
2397
2398 match batches[0][doc_col].data_type() {
2399 DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
2400 DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
2401 data_type => Err(Error::invalid_input(
2402 format!("unsupported data type {} for inverted index", data_type),
2403 location!(),
2404 )),
2405 }
2406}
2407
2408fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
2409 batches: &[&RecordBatch],
2410 doc_col: &str,
2411 query: &str,
2412 tokenizer: Option<Box<dyn LanceTokenizer>>,
2413) -> Result<Vec<u64>> {
2414 let mut results = Vec::new();
2415 let mut tokenizer =
2416 tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
2417 let query_tokens = collect_query_tokens(query, &mut tokenizer, None);
2418
2419 for batch in batches {
2420 let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
2421 let doc_array = batch[doc_col].as_string::<Offset>();
2422 for i in 0..row_id_array.len() {
2423 let doc = doc_array.value(i);
2424 let doc_tokens = collect_doc_tokens(doc, &mut tokenizer, Some(&query_tokens));
2425 if !doc_tokens.is_empty() {
2426 results.push(row_id_array.value(i));
2427 assert!(doc.contains(query));
2428 }
2429 }
2430 }
2431
2432 Ok(results)
2433}
2434
/// Scores every row of `batch` against `query_tokens` with BM25 and appends
/// the result as a score column.
///
/// `scorer` accumulates corpus statistics (doc count, average length,
/// per-token doc frequencies) as documents stream through, so documents seen
/// in earlier batches influence the IDF applied to later ones.
///
/// Returns `batch` extended with the score column and projected to `schema`.
#[allow(clippy::too_many_arguments)]
pub fn flat_bm25_search(
    batch: RecordBatch,
    doc_col: &str,
    query_tokens: &Tokens,
    tokenizer: &mut Box<dyn LanceTokenizer>,
    scorer: &mut MemBM25Scorer,
    schema: SchemaRef,
) -> std::result::Result<RecordBatch, DataFusionError> {
    let doc_iter = iter_str_array(&batch[doc_col]);
    let mut scores = Vec::with_capacity(batch.num_rows());
    for doc in doc_iter {
        // Null documents score 0.
        let Some(doc) = doc else {
            scores.push(0.0);
            continue;
        };

        // Tokenize the full document and feed it to the scorer so the corpus
        // statistics stay current.
        let doc_tokens = collect_doc_tokens(doc, tokenizer, None);
        scorer.update(&doc_tokens);
        // Keep only the tokens that also occur in the query; only those
        // contribute term frequency to the score.
        let doc_tokens = doc_tokens
            .into_iter()
            .filter(|t| query_tokens.contains(t))
            .collect::<Vec<_>>();

        // BM25 length normalization term: k1 * (1 - b + b * |D| / avgdl).
        // NOTE(review): |D| here is the *filtered* token count, not the full
        // document length fed to the scorer above — classic BM25 uses the
        // total token count; confirm this asymmetry is intentional.
        let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / scorer.avg_doc_length());
        // Term frequency of each query token within this document.
        let mut doc_token_count = HashMap::new();
        for token in doc_tokens {
            doc_token_count
                .entry(token)
                .and_modify(|count| *count += 1)
                .or_insert(1);
        }
        let mut score = 0.0;
        for token in query_tokens {
            let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;

            // Standard BM25 term contribution: idf * tf * (k1 + 1) / (tf + norm).
            let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
            score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
        }
        scores.push(score);
    }

    let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
    let batch = batch
        .try_with_column(SCORE_FIELD.clone(), score_col)?
        .project_by_schema(&schema)?;
    Ok(batch)
}
2483
/// Wraps `input` in a stream that BM25-scores each batch against `query` and
/// filters out rows with non-positive scores.
///
/// When an FTS `index` is available, its tokenizer and corpus statistics
/// (doc count, per-token doc frequencies) seed the scorer so the flat search
/// uses the same global IDF as the index. Otherwise a plain whitespace
/// tokenizer and empty statistics are used and the scorer learns statistics
/// from the streamed documents themselves.
pub fn flat_bm25_search_stream(
    input: SendableRecordBatchStream,
    doc_col: String,
    query: String,
    index: &Option<InvertedIndex>,
    schema: SchemaRef,
) -> SendableRecordBatchStream {
    let mut tokenizer = match index {
        Some(index) => index.tokenizer(),
        None => Box::new(TextTokenizer::new(
            tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::SimpleTokenizer::default(),
            )
            .build(),
        )),
    };
    let tokens = collect_query_tokens(&query, &mut tokenizer, None);

    let mut bm25_scorer = match index {
        Some(index) => {
            let index_bm25_scorer =
                IndexBM25Scorer::new(index.partitions.iter().map(|p| p.as_ref()));
            if index_bm25_scorer.num_docs() == 0 {
                // Empty index: start with empty statistics.
                MemBM25Scorer::new(0, 0, HashMap::new())
            } else {
                // Seed per-token doc counts from the index, clamped to at
                // least 1 so the IDF stays finite for every query token.
                let mut token_docs = HashMap::with_capacity(tokens.len());
                for token in &tokens {
                    let token_nq = index_bm25_scorer.num_docs_containing_token(token).max(1);
                    token_docs.insert(token.clone(), token_nq);
                }
                // NOTE(review): the total token count is reconstructed as
                // avg_doc_length * num_docs, which loses precision through
                // the float average — approximate, but only used for scoring.
                MemBM25Scorer::new(
                    index_bm25_scorer.avg_doc_length() as u64 * index_bm25_scorer.num_docs() as u64,
                    index_bm25_scorer.num_docs(),
                    token_docs,
                )
            }
        }
        None => MemBM25Scorer::new(0, 0, HashMap::new()),
    };

    let batch_schema = schema.clone();
    let stream = input.map(move |batch| {
        let batch = batch?;

        let batch = flat_bm25_search(
            batch,
            &doc_col,
            &tokens,
            &mut tokenizer,
            &mut bm25_scorer,
            batch_schema.clone(),
        )?;

        // Drop rows whose score is not strictly positive (no query token
        // matched, or the document was null).
        let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
        let mask = score_col
            .iter()
            .map(|score| score.is_some_and(|score| score > 0.0))
            .collect::<Vec<_>>();
        let mask = BooleanArray::from(mask);
        let batch = arrow::compute::filter_record_batch(&batch, &mask)?;
        debug_assert!(batch[ROW_ID].null_count() == 0, "flat FTS produces nulls");
        Ok(batch)
    });

    Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream
}
2551
/// Returns `true` when `query` is a phrase query, i.e. it is wrapped in
/// double quotes (`"..."`).
///
/// The length guard prevents a degenerate single-character `"` from being
/// misclassified: for a one-character string the same quote would satisfy
/// both `starts_with` and `ends_with`.
pub fn is_phrase_query(query: &str) -> bool {
    query.len() >= 2 && query.starts_with('\"') && query.ends_with('\"')
}
2555
2556#[cfg(test)]
2557mod tests {
2558 use crate::scalar::inverted::lance_tokenizer::DocType;
2559 use lance_core::cache::LanceCache;
2560 use lance_core::utils::tempfile::TempObjDir;
2561 use lance_io::object_store::ObjectStore;
2562
2563 use crate::metrics::NoOpMetricsCollector;
2564 use crate::prefilter::NoFilter;
2565 use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
2566 use crate::scalar::inverted::encoding::decompress_posting_list;
2567 use crate::scalar::inverted::query::{FtsSearchParams, Operator};
2568 use crate::scalar::lance_format::LanceIndexStore;
2569
2570 use super::*;
2571
    /// Removing doc ids from a posting list via `remap` must compact the
    /// list, and the compacted list must round-trip through `to_batch` /
    /// `decompress_posting_list` unchanged.
    #[tokio::test]
    async fn test_posting_builder_remap() {
        // Build a posting list slightly larger than one block so the remap
        // crosses a block boundary.
        let mut builder = PostingListBuilder::new(false);
        let n = BLOCK_SIZE + 3;
        for i in 0..n {
            builder.add(i as u32, PositionRecorder::Count(1));
        }
        let removed = vec![5, 7];
        builder.remap(&removed);

        // After removal, the surviving docs are renumbered 0..n-2.
        let mut expected = PostingListBuilder::new(false);
        for i in 0..n - removed.len() {
            expected.add(i as u32, PositionRecorder::Count(1));
        }
        assert_eq!(builder.doc_ids, expected.doc_ids);
        assert_eq!(builder.frequencies, expected.frequencies);

        // Encode, then decode, and compare against the expected list.
        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
        let (doc_ids, freqs) = decompress_posting_list(
            (n - removed.len()) as u32,
            batch[POSTING_COL]
                .as_list::<i32>()
                .value(0)
                .as_binary::<i64>(),
        )
        .unwrap();
        assert!(doc_ids
            .iter()
            .zip(expected.doc_ids.iter())
            .all(|(a, b)| a == b));
        assert!(freqs
            .iter()
            .zip(expected.frequencies.iter())
            .all(|(a, b)| a == b));
    }
2609
    /// Remapping away every document of a token must drop the token and its
    /// posting list, and remapping away everything must leave an empty index
    /// that can still be written.
    #[tokio::test]
    async fn test_remap_to_empty_posting_list() {
        let tmpdir = TempObjDir::default();
        let store = Arc::new(LanceIndexStore::new(
            ObjectStore::local().into(),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());

        // Two tokens: "lance" only in doc 0, "lake" in docs 1 and 2.
        builder.tokens.add("lance".to_owned());
        builder.tokens.add("lake".to_owned());
        builder.posting_lists.push(PostingListBuilder::new(false));
        builder.posting_lists.push(PostingListBuilder::new(false));
        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
        builder.posting_lists[1].add(1, PositionRecorder::Count(2));
        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
        builder.docs.append(0, 1);
        builder.docs.append(1, 1);
        builder.docs.append(2, 1);
        builder.write(store.as_ref()).await.unwrap();

        let index = InvertedPartition::load(
            store.clone(),
            0,
            None,
            &LanceCache::no_cache(),
            TokenSetFormat::default(),
        )
        .await
        .unwrap();
        let mut builder = index.into_builder().await.unwrap();

        // Delete row 0 ("lance"'s only doc) and move row 2 to row 3.
        let mapping = HashMap::from([(0, None), (2, Some(3))]);
        builder.remap(&mapping).await.unwrap();

        // "lance" is gone; "lake" survives with its two (remapped) docs.
        assert_eq!(builder.tokens.len(), 1);
        assert_eq!(builder.tokens.get("lake"), Some(0));
        assert_eq!(builder.posting_lists.len(), 1);
        assert_eq!(builder.posting_lists[0].len(), 2);
        assert_eq!(builder.docs.len(), 2);
        assert_eq!(builder.docs.row_id(0), 1);
        assert_eq!(builder.docs.row_id(1), 3);

        builder.write(store.as_ref()).await.unwrap();

        // Delete the remaining docs: the index becomes completely empty.
        let mapping = HashMap::from([(1, None), (3, None)]);
        builder.remap(&mapping).await.unwrap();

        assert_eq!(builder.tokens.len(), 0);
        assert_eq!(builder.posting_lists.len(), 0);
        assert_eq!(builder.docs.len(), 0);

        builder.write(store.as_ref()).await.unwrap();
    }
2672
2673 #[tokio::test]
2674 async fn test_posting_cache_conflict_across_partitions() {
2675 let tmpdir = TempObjDir::default();
2676 let store = Arc::new(LanceIndexStore::new(
2677 ObjectStore::local().into(),
2678 tmpdir.clone(),
2679 Arc::new(LanceCache::no_cache()),
2680 ));
2681
2682 let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
2684 builder1.tokens.add("test".to_owned());
2685 builder1.posting_lists.push(PostingListBuilder::new(false));
2686 builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
2687 builder1.docs.append(100, 1); builder1.write(store.as_ref()).await.unwrap();
2689
2690 let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
2692 builder2.tokens.add("test".to_owned()); builder2.posting_lists.push(PostingListBuilder::new(false));
2694 builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
2695 builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
2696 builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
2697 builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
2698 builder2.docs.append(200, 2); builder2.docs.append(201, 1); builder2.docs.append(202, 3); builder2.docs.append(203, 1); builder2.write(store.as_ref()).await.unwrap();
2703
2704 let metadata = std::collections::HashMap::from_iter(vec![
2706 (
2707 "partitions".to_owned(),
2708 serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
2709 ),
2710 (
2711 "params".to_owned(),
2712 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
2713 ),
2714 (
2715 TOKEN_SET_FORMAT_KEY.to_owned(),
2716 TokenSetFormat::default().to_string(),
2717 ),
2718 ]);
2719 let mut writer = store
2720 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
2721 .await
2722 .unwrap();
2723 writer.finish_with_metadata(metadata).await.unwrap();
2724
2725 let cache = Arc::new(LanceCache::with_capacity(4096));
2727 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
2728 .await
2729 .unwrap();
2730
2731 assert_eq!(index.partitions.len(), 2);
2733 assert_eq!(index.partitions[0].tokens.len(), 1);
2734 assert_eq!(index.partitions[1].tokens.len(), 1);
2735
2736 if index.partitions[0].id() == 0 {
2741 assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
2743 assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
2744 assert_eq!(index.partitions[0].docs.len(), 1);
2745 assert_eq!(index.partitions[1].docs.len(), 4);
2746 } else {
2747 assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
2749 assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
2750 assert_eq!(index.partitions[0].docs.len(), 4);
2751 assert_eq!(index.partitions[1].docs.len(), 1);
2752 }
2753
2754 index.prewarm().await.unwrap();
2756
2757 let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
2758 let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
2759 let prefilter = Arc::new(NoFilter);
2760 let metrics = Arc::new(NoOpMetricsCollector);
2761
2762 let (row_ids, scores) = index
2763 .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
2764 .await
2765 .unwrap();
2766
2767 assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
2770 assert!(!row_ids.is_empty(), "Should find at least some documents");
2771 assert_eq!(row_ids.len(), scores.len());
2772
2773 for &score in &scores {
2775 assert!(score > 0.0, "All scores should be positive");
2776 }
2777
2778 assert!(
2780 row_ids.contains(&100),
2781 "Should contain row_id from partition 0"
2782 );
2783 assert!(
2784 row_ids.iter().any(|&id| id >= 200),
2785 "Should contain row_id from partition 1"
2786 );
2787 }
2788
2789 #[test]
2790 fn test_block_max_scores_capacity_matches_block_count() {
2791 let mut docs = DocSet::default();
2792 let num_docs = BLOCK_SIZE * 3 + 7;
2793 let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
2794 for doc_id in &doc_ids {
2795 docs.append(*doc_id as u64, 1);
2796 }
2797
2798 let freqs = vec![1_u32; doc_ids.len()];
2799 let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
2800 let expected_blocks = doc_ids.len().div_ceil(BLOCK_SIZE);
2801
2802 assert_eq!(block_max_scores.len(), expected_blocks);
2803 assert_eq!(block_max_scores.capacity(), expected_blocks);
2804 }
2805
    /// BM25 over a multi-partition index must compute IDF from global corpus
    /// statistics (all partitions), not per-partition ones: here "alpha"
    /// appears in 2 of 4 docs overall, so both matches score idf(2, 4).
    #[tokio::test]
    async fn test_bm25_search_uses_global_idf() {
        let tmpdir = TempObjDir::default();
        let store = Arc::new(LanceIndexStore::new(
            ObjectStore::local().into(),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        // Partition 0: "alpha" in row 100; "beta" in rows 101 and 102.
        let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default());
        builder0.tokens.add("alpha".to_owned());
        builder0.tokens.add("beta".to_owned());
        builder0.posting_lists.push(PostingListBuilder::new(false));
        builder0.posting_lists.push(PostingListBuilder::new(false));
        builder0.posting_lists[0].add(0, PositionRecorder::Count(1));
        builder0.posting_lists[1].add(1, PositionRecorder::Count(1));
        builder0.posting_lists[1].add(2, PositionRecorder::Count(1));
        builder0.docs.append(100, 1);
        builder0.docs.append(101, 1);
        builder0.docs.append(102, 1);
        builder0.write(store.as_ref()).await.unwrap();

        // Partition 1: "alpha" in row 200.
        let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default());
        builder1.tokens.add("alpha".to_owned());
        builder1.posting_lists.push(PostingListBuilder::new(false));
        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
        builder1.docs.append(200, 1);
        builder1.write(store.as_ref()).await.unwrap();

        // Index metadata listing both partitions.
        let metadata = std::collections::HashMap::from_iter(vec![
            (
                "partitions".to_owned(),
                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
            ),
            (
                "params".to_owned(),
                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
            ),
            (
                TOKEN_SET_FORMAT_KEY.to_owned(),
                TokenSetFormat::default().to_string(),
            ),
        ]);
        let mut writer = store
            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
            .await
            .unwrap();
        writer.finish_with_metadata(metadata).await.unwrap();

        let cache = Arc::new(LanceCache::with_capacity(4096));
        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
            .await
            .unwrap();

        let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text));
        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
        let prefilter = Arc::new(NoFilter);
        let metrics = Arc::new(NoOpMetricsCollector);

        let (row_ids, scores) = index
            .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
            .await
            .unwrap();

        assert_eq!(row_ids.len(), 2);
        assert!(row_ids.contains(&100));
        assert!(row_ids.contains(&200));
        assert_eq!(row_ids.len(), scores.len());

        // "alpha" appears in 2 of the 4 docs across both partitions; with
        // tf = 1 and uniform doc lengths the score reduces to the IDF.
        let expected_idf = idf(2, 4);
        for score in scores {
            assert!(
                (score - expected_idf).abs() < 1e-6,
                "score: {}, expected: {}",
                score,
                expected_idf
            );
        }
    }
2887}