1use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::{
7 cmp::{min, Reverse},
8 collections::BinaryHeap,
9};
10use std::{collections::HashMap, ops::Range};
11
12use crate::metrics::NoOpMetricsCollector;
13use crate::prefilter::NoFilter;
14use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
15use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
16use arrow::{
17 array::{
18 AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
19 },
20 buffer::OffsetBuffer,
21};
22use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
23use arrow_array::{
24 Array, ArrayRef, BooleanArray, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait,
25 RecordBatch, UInt32Array, UInt64Array,
26};
27use arrow_schema::{DataType, Field, Schema, SchemaRef};
28use async_trait::async_trait;
29use datafusion::execution::SendableRecordBatchStream;
30use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
31use datafusion_common::DataFusionError;
32use deepsize::DeepSizeOf;
33use fst::{Automaton, IntoStreamer, Streamer};
34use futures::{stream, FutureExt, StreamExt, TryStreamExt};
35use itertools::Itertools;
36use lance_arrow::{iter_str_array, RecordBatchExt};
37use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache};
38use lance_core::utils::mask::RowIdTreeMap;
39use lance_core::utils::{
40 mask::RowIdMask,
41 tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
42};
43use lance_core::{
44 container::list::ExpLinkedList,
45 utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu},
46};
47use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
48use roaring::RoaringBitmap;
49use snafu::location;
50use std::sync::LazyLock;
51use tokio::task::spawn_blocking;
52use tracing::{info, instrument};
53
54use super::{
55 builder::{
56 doc_file_path, inverted_list_schema, posting_file_path, token_file_path, ScoredDoc,
57 BLOCK_SIZE,
58 },
59 iter::PlainPostingListIterator,
60 query::*,
61 scorer::{idf, IndexBM25Scorer, Scorer, B, K1},
62};
63use super::{
64 builder::{InnerBuilder, PositionRecorder},
65 encoding::compress_posting_list,
66 iter::CompressedPostingListIterator,
67};
68use super::{
69 encoding::compress_positions,
70 iter::{PostingListIterator, TokenIterator, TokenSource},
71};
72use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
73use crate::frag_reuse::FragReuseIndex;
74use crate::pbold;
75use crate::scalar::inverted::lance_tokenizer::TextTokenizer;
76use crate::scalar::inverted::scorer::MemBM25Scorer;
77use crate::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
78use crate::scalar::{
79 AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
80 ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
81};
82use crate::Index;
83use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
84use std::str::FromStr;
85
/// On-disk format version of the inverted index.
pub const INVERTED_INDEX_VERSION: u32 = 1;

// File names used by the index layout. The legacy (non-partitioned) layout
// stores these directly at the store root; the partitioned layout derives
// per-partition paths elsewhere (see `token_file_path` etc.).
pub const TOKENS_FILE: &str = "tokens.lance";
pub const INVERT_LIST_FILE: &str = "invert.lance";
pub const DOCS_FILE: &str = "docs.lance";
pub const METADATA_FILE: &str = "metadata.lance";

// Column names used by the serialized index record batches.
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
pub const SCORE_COL: &str = "_score";
/// Schema-metadata key recording which `TokenSetFormat` an index was written with.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";

/// The `_score` output field (nullable Float32) of FTS results.
pub static SCORE_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
/// Output schema of a full text search: row id + score.
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
/// Schema holding only the row id column, used by token-membership searches.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
116
/// Serialization format of a partition's token set.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
    /// One Arrow row per token: `_token` / `_token_id` columns.
    Arrow,
    /// A single row holding serialized FST bytes (plus next id / total length).
    #[default]
    Fst,
}
123
124impl Display for TokenSetFormat {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 match self {
127 Self::Arrow => f.write_str("arrow"),
128 Self::Fst => f.write_str("fst"),
129 }
130 }
131}
132
133impl FromStr for TokenSetFormat {
134 type Err = Error;
135
136 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
137 match s.trim() {
138 "" => Ok(Self::Arrow),
139 "arrow" => Ok(Self::Arrow),
140 "fst" => Ok(Self::Fst),
141 other => Err(Error::Index {
142 message: format!("unsupported token set format {}", other),
143 location: location!(),
144 }),
145 }
146 }
147}
148
149impl DeepSizeOf for TokenSetFormat {
150 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
151 0
152 }
153}
154
/// A full text search (inverted) index, split into one or more partitions.
#[derive(Clone)]
pub struct InvertedIndex {
    // tokenizer/index parameters the index was built with
    params: InvertedIndexParams,
    store: Arc<dyn IndexStore>,
    // tokenizer built from `params`; cloned per query (see `tokenizer()`)
    tokenizer: Box<dyn LanceTokenizer>,
    token_set_format: TokenSetFormat,
    // a legacy index is represented as a single partition with id 0
    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
}
163
impl Debug for InvertedIndex {
    // `store` and `tokenizer` are intentionally omitted from the output
    // (trait objects; not shown here).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InvertedIndex")
            .field("params", &self.params)
            .field("token_set_format", &self.token_set_format)
            .field("partitions", &self.partitions)
            .finish()
    }
}
173
impl DeepSizeOf for InvertedIndex {
    // Only the partitions are counted; params/store/tokenizer are not measured.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.partitions.deep_size_of_children(context)
    }
}
179
impl InvertedIndex {
    /// Creates a builder seeded from this index, covering all partitions.
    fn to_builder(&self) -> InvertedIndexBuilder {
        self.to_builder_with_offset(None)
    }

    /// Creates a builder seeded from this index.
    ///
    /// When `fragment_mask` is `Some`, only partitions matching the mask (see
    /// `InvertedPartition::belongs_to_fragment`) are carried over. A legacy
    /// index carries no partitions/store so it is rebuilt from scratch.
    fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
        if self.is_legacy() {
            InvertedIndexBuilder::from_existing_index(
                self.params.clone(),
                None,
                Vec::new(),
                self.token_set_format,
                fragment_mask,
            )
        } else {
            let partitions = match fragment_mask {
                Some(fragment_mask) => self
                    .partitions
                    .iter()
                    .filter(|part| part.belongs_to_fragment(fragment_mask))
                    .map(|part| part.id())
                    .collect(),
                None => self.partitions.iter().map(|part| part.id()).collect(),
            };

            InvertedIndexBuilder::from_existing_index(
                self.params.clone(),
                Some(self.store.clone()),
                partitions,
                self.token_set_format,
                fragment_mask,
            )
        }
    }

    /// Returns a fresh clone of the tokenizer configured for this index.
    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
        self.tokenizer.clone()
    }

    /// The parameters this index was built with.
    pub fn params(&self) -> &InvertedIndexParams {
        &self.params
    }

    /// Runs a BM25-scored search for `tokens` across all partitions.
    ///
    /// Posting lists are loaded per partition concurrently; each partition is
    /// then searched on a compute thread (`spawn_cpu`). Per-partition
    /// candidates are re-scored with a global scorer (built over all
    /// partitions, so idf statistics are index-wide) and merged through a
    /// bounded min-heap keeping the top `params.limit` documents.
    ///
    /// Returns `(row_ids, scores)` of the winners, ordered by the heap's
    /// sorted order (descending score via `Reverse`).
    #[instrument(level = "debug", skip_all)]
    pub async fn bm25_search(
        &self,
        tokens: Arc<Tokens>,
        params: Arc<FtsSearchParams>,
        operator: Operator,
        prefilter: Arc<dyn PreFilter>,
        metrics: Arc<dyn MetricsCollector>,
    ) -> Result<(Vec<u64>, Vec<f32>)> {
        let limit = params.limit.unwrap_or(usize::MAX);
        if limit == 0 {
            return Ok((Vec::new(), Vec::new()));
        }
        let mask = prefilter.mask();

        let mut candidates = BinaryHeap::new();
        let parts = self
            .partitions
            .iter()
            .map(|part| {
                let part = part.clone();
                let tokens = tokens.clone();
                let params = params.clone();
                let mask = mask.clone();
                let metrics = metrics.clone();
                async move {
                    // I/O phase: fetch the posting lists on the async runtime.
                    let postings = part
                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
                        .await?;
                    if postings.is_empty() {
                        return Ok(Vec::new());
                    }
                    let params = params.clone();
                    let mask = mask.clone();
                    let metrics = metrics.clone();
                    // CPU phase: run WAND off the async runtime.
                    spawn_cpu(move || {
                        part.bm25_search(
                            params.as_ref(),
                            operator,
                            mask,
                            postings,
                            metrics.as_ref(),
                        )
                    })
                    .await
                }
            })
            .collect::<Vec<_>>();
        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
        let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
        while let Some(res) = parts.try_next().await? {
            for DocCandidate {
                row_id,
                freqs,
                doc_length,
            } in res
            {
                // Re-score with global (cross-partition) statistics.
                let mut score = 0.0;
                for (token, freq) in freqs.into_iter() {
                    score += scorer.score(token.as_str(), freq, doc_length);
                }
                if candidates.len() < limit {
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                } else if candidates.peek().unwrap().0.score.0 < score {
                    // Heap is full: replace the current minimum if beaten.
                    candidates.pop();
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                }
            }
        }

        Ok(candidates
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
            .unzip())
    }

    /// Loads an index written in the legacy (non-partitioned) layout: a
    /// token/posting/doc file triple at the store root, exposed as a single
    /// partition with id 0. The three files are opened concurrently.
    async fn load_legacy_index(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>> {
        log::warn!("loading legacy FTS index");
        let tokens_fut = tokio::spawn({
            let store = store.clone();
            async move {
                let token_reader = store.open_index_file(TOKENS_FILE).await?;
                // Legacy files keep the index params under the "tokenizer"
                // schema-metadata key; fall back to defaults when absent.
                let tokenizer = token_reader
                    .schema()
                    .metadata
                    .get("tokenizer")
                    .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
                    .transpose()?
                    .unwrap_or_default();
                let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
                Result::Ok((tokenizer, tokens))
            }
        });
        let invert_list_fut = tokio::spawn({
            let store = store.clone();
            let index_cache_clone = index_cache.clone();
            async move {
                let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
                let invert_list =
                    PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
                Result::Ok(Arc::new(invert_list))
            }
        });
        let docs_fut = tokio::spawn({
            let store = store.clone();
            async move {
                let docs_reader = store.open_index_file(DOCS_FILE).await?;
                let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
                Result::Ok(docs)
            }
        });

        let (tokenizer_config, tokens) = tokens_fut.await??;
        let inverted_list = invert_list_fut.await??;
        let docs = docs_fut.await??;

        let tokenizer = tokenizer_config.build()?;

        Ok(Arc::new(Self {
            params: tokenizer_config,
            store: store.clone(),
            tokenizer,
            token_set_format: TokenSetFormat::Arrow,
            partitions: vec![Arc::new(InvertedPartition {
                id: 0,
                store,
                tokens,
                inverted_list,
                docs,
                token_set_format: TokenSetFormat::Arrow,
            })],
        }))
    }

    /// A legacy index is a single partition whose posting file lacks lengths.
    pub fn is_legacy(&self) -> bool {
        self.partitions.len() == 1 && self.partitions[0].is_legacy()
    }

    /// Loads an inverted index from `store`.
    ///
    /// Tries the partitioned layout first (a metadata file listing params and
    /// partition ids); if the metadata file cannot be opened, falls back to
    /// the legacy single-partition layout.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
    ) -> Result<Arc<Self>>
    where
        Self: Sized,
    {
        match store.open_index_file(METADATA_FILE).await {
            Ok(reader) => {
                let params = reader.schema().metadata.get("params").ok_or(Error::Index {
                    message: "params not found in metadata".to_owned(),
                    location: location!(),
                })?;
                let params = serde_json::from_str::<InvertedIndexParams>(params)?;
                let partitions =
                    reader
                        .schema()
                        .metadata
                        .get("partitions")
                        .ok_or(Error::Index {
                            message: "partitions not found in metadata".to_owned(),
                            location: location!(),
                        })?;
                let partitions: Vec<u64> = serde_json::from_str(partitions)?;
                // Missing format key means the index predates the key: Arrow.
                let token_set_format = reader
                    .schema()
                    .metadata
                    .get(TOKEN_SET_FORMAT_KEY)
                    .map(|name| TokenSetFormat::from_str(name))
                    .transpose()?
                    .unwrap_or(TokenSetFormat::Arrow);

                let format = token_set_format;
                let partitions = partitions.into_iter().map(|id| {
                    let store = store.clone();
                    let frag_reuse_index_clone = frag_reuse_index.clone();
                    // Scope cache entries per partition to avoid key clashes.
                    let index_cache_for_part =
                        index_cache.with_key_prefix(format!("part-{}", id).as_str());
                    let token_set_format = format;
                    async move {
                        Result::Ok(Arc::new(
                            InvertedPartition::load(
                                store,
                                id,
                                frag_reuse_index_clone,
                                &index_cache_for_part,
                                token_set_format,
                            )
                            .await?,
                        ))
                    }
                });
                let partitions = stream::iter(partitions)
                    .buffer_unordered(store.io_parallelism())
                    .try_collect::<Vec<_>>()
                    .await?;

                let tokenizer = params.build()?;
                Ok(Arc::new(Self {
                    params,
                    store,
                    tokenizer,
                    token_set_format,
                    partitions,
                }))
            }
            Err(_) => {
                // No metadata file: assume the legacy layout.
                Self::load_legacy_index(store, frag_reuse_index, index_cache).await
            }
        }
    }
}
453
#[async_trait]
impl Index for InvertedIndex {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    /// An inverted index is not a vector index; this always errors.
    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
        Err(Error::invalid_input(
            "inverted index cannot be cast to vector index",
            location!(),
        ))
    }

    /// Index statistics: params plus token/doc counts summed over partitions.
    fn statistics(&self) -> Result<serde_json::Value> {
        let num_tokens = self
            .partitions
            .iter()
            .map(|part| part.tokens.len())
            .sum::<usize>();
        let num_docs = self
            .partitions
            .iter()
            .map(|part| part.docs.len())
            .sum::<usize>();
        Ok(serde_json::json!({
            "params": self.params,
            "num_tokens": num_tokens,
            "num_docs": num_docs,
        }))
    }

    /// Pre-loads posting list data for every partition into cache.
    async fn prewarm(&self) -> Result<()> {
        for part in &self.partitions {
            part.inverted_list.prewarm().await?;
        }
        Ok(())
    }

    fn index_type(&self) -> crate::IndexType {
        crate::IndexType::Inverted
    }

    // Not supported for inverted indexes.
    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        unimplemented!()
    }
}
504
impl InvertedIndex {
    /// Tokenizes `text` and runs a conjunctive (AND) BM25 search with default
    /// parameters and no prefilter, returning the matching row ids as a
    /// single-column `_rowid` batch (scores are discarded).
    async fn do_search(&self, text: &str) -> Result<RecordBatch> {
        let params = FtsSearchParams::new();
        let mut tokenizer = self.tokenizer.clone();
        let tokens = collect_query_tokens(text, &mut tokenizer, None);

        let (doc_ids, _) = self
            .bm25_search(
                Arc::new(tokens),
                params.into(),
                Operator::And,
                Arc::new(NoFilter),
                Arc::new(NoOpMetricsCollector),
            )
            .boxed()
            .await?;

        Ok(RecordBatch::try_new(
            ROW_ID_SCHEMA.clone(),
            vec![Arc::new(UInt64Array::from(doc_ids))],
        )?)
    }
}
529
#[async_trait]
impl ScalarIndex for InvertedIndex {
    /// Token-membership search entry point used by the scalar index plumbing.
    #[instrument(level = "debug", skip_all)]
    async fn search(
        &self,
        query: &dyn AnyQuery,
        _metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        // Only `TokenQuery` is routed here; any other type is a caller bug.
        let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();

        match query {
            TokenQuery::TokensContains(text) => {
                let records = self.do_search(text).await?;
                let row_ids = records
                    .column(0)
                    .as_any()
                    .downcast_ref::<UInt64Array>()
                    .unwrap();
                let row_ids = row_ids.iter().flatten().collect_vec();
                // AtMost: the prefilter-free search may overreport matches.
                Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids)))
            }
        }
    }

    fn can_remap(&self) -> bool {
        true
    }

    /// Rewrites the index into `dest_store` with row ids remapped
    /// (e.g. after compaction). `None` in the mapping means the row was deleted.
    async fn remap(
        &self,
        mapping: &HashMap<u64, Option<u64>>,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        self.to_builder()
            .remap(mapping, self.store.clone(), dest_store)
            .await?;

        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;

        // NOTE(review): Arrow token sets report version 0 — presumably the
        // pre-versioned layout; FST token sets use the current version.
        let index_version = match self.token_set_format {
            TokenSetFormat::Arrow => 0,
            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
        };

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&details).unwrap(),
            index_version,
        })
    }

    /// Indexes `new_data` and writes the updated index into `dest_store`.
    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        self.to_builder().update(new_data, dest_store).await?;

        let details = pbold::InvertedIndexDetails::try_from(&self.params)?;

        // Same version rule as `remap` above.
        let index_version = match self.token_set_format {
            TokenSetFormat::Arrow => 0,
            TokenSetFormat::Fst => INVERTED_INDEX_VERSION,
        };

        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&details).unwrap(),
            index_version,
        })
    }

    /// Legacy indexes must be rebuilt with the old data on update; partitioned
    /// indexes only need the new rows.
    fn update_criteria(&self) -> UpdateCriteria {
        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
        if self.is_legacy() {
            UpdateCriteria::requires_old_data(criteria)
        } else {
            UpdateCriteria::only_new_data(criteria)
        }
    }

    /// Reconstructs `ScalarIndexParams` that would recreate this index.
    fn derive_index_params(&self) -> Result<ScalarIndexParams> {
        let mut params = self.params.clone();
        // Older indexes may lack an explicit base tokenizer; default it.
        if params.base_tokenizer.is_empty() {
            params.base_tokenizer = "simple".to_string();
        }

        let params_json = serde_json::to_string(&params)?;

        Ok(ScalarIndexParams {
            index_type: BuiltinIndexType::Inverted.as_str().to_string(),
            params: Some(params_json),
        })
    }
}
626
/// One partition of an inverted index: its token set, posting lists and docs.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct InvertedPartition {
    // partition id; also consulted as a bitmask by `belongs_to_fragment`
    // (NOTE(review): confirm id encoding against the builder that assigns it)
    id: u64,
    store: Arc<dyn IndexStore>,
    pub(crate) tokens: TokenSet,
    pub(crate) inverted_list: Arc<PostingListReader>,
    pub(crate) docs: DocSet,
    token_set_format: TokenSetFormat,
}
637
impl InvertedPartition {
    /// True when every bit of `fragment_mask` is set in this partition's id.
    // NOTE(review): assumes partition ids embed fragment membership as a
    // bitmask — confirm against the id assignment logic in the builder.
    pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
        (self.id() & fragment_mask) == fragment_mask
    }

    /// The partition id.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// The store this partition was loaded from.
    pub fn store(&self) -> &dyn IndexStore {
        self.store.as_ref()
    }

    /// Legacy posting files carry no per-token lengths.
    pub fn is_legacy(&self) -> bool {
        self.inverted_list.lengths.is_none()
    }

    /// Loads one partition (token set, posting lists, doc set) from `store`,
    /// using the per-partition file paths derived from `id`.
    pub async fn load(
        store: Arc<dyn IndexStore>,
        id: u64,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        index_cache: &LanceCache,
        token_set_format: TokenSetFormat,
    ) -> Result<Self> {
        let token_file = store.open_index_file(&token_file_path(id)).await?;
        let tokens = TokenSet::load(token_file, token_set_format).await?;
        let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
        let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
        let docs_file = store.open_index_file(&doc_file_path(id)).await?;
        let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;

        Ok(Self {
            id,
            store,
            tokens,
            inverted_list: Arc::new(inverted_list),
            docs,
            token_set_format,
        })
    }

    /// Maps a token to its id in this partition, if present.
    fn map(&self, token: &str) -> Option<u32> {
        self.tokens.get(token)
    }

    /// Expands each query token to indexed tokens within the Levenshtein
    /// distance given by `params.fuzziness` (auto-derived when `None`),
    /// optionally constrained to a shared prefix. Requires the token set to
    /// be FST-backed; errors otherwise. At most `params.max_expansions`
    /// expansions are collected per search.
    pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
        let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
        for token in tokens {
            let fuzziness = match params.fuzziness {
                Some(fuzziness) => fuzziness,
                None => MatchQuery::auto_fuzziness(token),
            };
            let lev =
                fst::automaton::Levenshtein::new(token, fuzziness).map_err(|e| Error::Index {
                    message: format!("failed to construct the fuzzy query: {}", e),
                    location: location!(),
                })?;

            let base_len = tokens.token_type().prefix_len(token) as u32;
            if let TokenMap::Fst(ref map) = self.tokens.tokens {
                match base_len + params.prefix_length {
                    // No prefix constraint: search with Levenshtein only.
                    0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
                    prefix_length => {
                        // Constrain candidates to share the leading prefix.
                        let prefix = &token[..min(prefix_length as usize, token.len())];
                        let prefix = fst::automaton::Str::new(prefix).starts_with();
                        take_fst_keys(
                            map.search(lev.intersection(prefix)),
                            &mut new_tokens,
                            params.max_expansions,
                        )
                    }
                }
            } else {
                return Err(Error::Index {
                    message: "tokens is not fst, which is not expected".to_owned(),
                    location: location!(),
                });
            }
        }
        Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
    }

    /// Resolves the query tokens to ids and loads their posting lists.
    ///
    /// - Fuzzy queries first expand the token set via `expand_fuzzy`.
    /// - Phrase queries preserve token order/duplicates and return empty if
    ///   any token is missing (the phrase cannot match).
    /// - Other queries sort and dedup the token ids.
    #[instrument(level = "debug", skip_all)]
    pub async fn load_posting_lists(
        &self,
        tokens: &Tokens,
        params: &FtsSearchParams,
        metrics: &dyn MetricsCollector,
    ) -> Result<Vec<PostingIterator>> {
        let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
        let is_phrase_query = params.phrase_slop.is_some();
        let tokens = match is_fuzzy {
            true => self.expand_fuzzy(tokens, params)?,
            false => tokens.clone(),
        };
        let mut token_ids = Vec::with_capacity(tokens.len());
        for token in tokens {
            let token_id = self.map(&token);
            if let Some(token_id) = token_id {
                token_ids.push((token_id, token));
            } else if is_phrase_query {
                // A missing token makes the whole phrase unmatched.
                return Ok(Vec::new());
            }
        }
        if token_ids.is_empty() {
            return Ok(Vec::new());
        }
        if !is_phrase_query {
            token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
            token_ids.dedup_by_key(|(token_id, _)| *token_id);
        }

        let num_docs = self.docs.len();
        stream::iter(token_ids)
            .enumerate()
            .map(|(position, (token_id, token))| async move {
                let posting = self
                    .inverted_list
                    .posting_list(token_id, is_phrase_query, metrics)
                    .await?;

                Result::Ok(PostingIterator::new(
                    token,
                    token_id,
                    position as u32,
                    posting,
                    num_docs,
                ))
            })
            .buffered(self.store.io_parallelism())
            .try_collect::<Vec<_>>()
            .await
    }

    /// Runs WAND over the posting iterators against this partition's docs
    /// and returns candidate documents (scoring is finalized by the caller).
    #[instrument(level = "debug", skip_all)]
    pub fn bm25_search(
        &self,
        params: &FtsSearchParams,
        operator: Operator,
        mask: Arc<RowIdMask>,
        postings: Vec<PostingIterator>,
        metrics: &dyn MetricsCollector,
    ) -> Result<Vec<DocCandidate>> {
        if postings.is_empty() {
            return Ok(Vec::new());
        }

        // Partition-local scorer: used for WAND's score upper bounds.
        let scorer = IndexBM25Scorer::new(std::iter::once(self));
        let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
        let hits = wand.search(params, mask, metrics)?;
        Ok(hits)
    }

    /// Converts this partition back into a mutable builder, reading all
    /// posting lists into memory (positions included when stored).
    pub async fn into_builder(self) -> Result<InnerBuilder> {
        let mut builder = InnerBuilder::new(
            self.id,
            self.inverted_list.has_positions(),
            self.token_set_format,
        );
        builder.tokens = self.tokens;
        builder.docs = self.docs;

        builder
            .posting_lists
            .reserve_exact(self.inverted_list.len());
        for posting_list in self
            .inverted_list
            .read_all(self.inverted_list.has_positions())
            .await?
        {
            let posting_list = posting_list?;
            builder
                .posting_lists
                .push(posting_list.into_builder(&builder.docs));
        }
        Ok(builder)
    }
}
833
/// Backing storage for the token -> token-id mapping:
/// a `HashMap` while building/mutating, an FST once serialized or loaded.
#[derive(Debug, Clone)]
pub enum TokenMap {
    HashMap(HashMap<String, u32>),
    Fst(fst::Map<Vec<u8>>),
}
841
842impl Default for TokenMap {
843 fn default() -> Self {
844 Self::HashMap(HashMap::new())
845 }
846}
847
848impl DeepSizeOf for TokenMap {
849 fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
850 match self {
851 Self::HashMap(map) => map.deep_size_of_children(ctx),
852 Self::Fst(map) => map.as_fst().size(),
853 }
854 }
855}
856
857impl TokenMap {
858 pub fn len(&self) -> usize {
859 match self {
860 Self::HashMap(map) => map.len(),
861 Self::Fst(map) => map.len(),
862 }
863 }
864
865 pub fn is_empty(&self) -> bool {
866 self.len() == 0
867 }
868}
869
/// Token -> token-id mapping of one partition, plus bookkeeping needed to
/// assign new ids and to size serialization buffers.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct TokenSet {
    // token -> token id
    pub(crate) tokens: TokenMap,
    // id that will be assigned to the next previously-unseen token
    pub(crate) next_id: u32,
    // summed byte length of all distinct tokens (used to pre-size builders)
    total_length: usize,
}
878
879impl TokenSet {
880 pub fn into_mut(self) -> Self {
881 let tokens = match self.tokens {
882 TokenMap::HashMap(map) => map,
883 TokenMap::Fst(map) => {
884 let mut new_map = HashMap::with_capacity(map.len());
885 let mut stream = map.into_stream();
886 while let Some((token, token_id)) = stream.next() {
887 new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
888 }
889
890 new_map
891 }
892 };
893
894 Self {
895 tokens: TokenMap::HashMap(tokens),
896 next_id: self.next_id,
897 total_length: self.total_length,
898 }
899 }
900
901 pub fn len(&self) -> usize {
902 self.tokens.len()
903 }
904
905 pub fn is_empty(&self) -> bool {
906 self.len() == 0
907 }
908
909 pub(crate) fn iter(&self) -> TokenIterator<'_> {
910 TokenIterator::new(match &self.tokens {
911 TokenMap::HashMap(map) => TokenSource::HashMap(map.iter()),
912 TokenMap::Fst(map) => TokenSource::Fst(map.stream()),
913 })
914 }
915
916 pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
917 match format {
918 TokenSetFormat::Arrow => self.into_arrow_batch(),
919 TokenSetFormat::Fst => self.into_fst_batch(),
920 }
921 }
922
923 fn into_arrow_batch(self) -> Result<RecordBatch> {
924 let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
925 let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
926
927 match self.tokens {
928 TokenMap::Fst(map) => {
929 let mut stream = map.stream();
930 while let Some((token, token_id)) = stream.next() {
931 token_builder.append_value(String::from_utf8_lossy(token));
932 token_id_builder.append_value(token_id as u32);
933 }
934 }
935 TokenMap::HashMap(map) => {
936 for (token, token_id) in map.into_iter().sorted_unstable() {
937 token_builder.append_value(token);
938 token_id_builder.append_value(token_id);
939 }
940 }
941 }
942
943 let token_col = token_builder.finish();
944 let token_id_col = token_id_builder.finish();
945
946 let schema = arrow_schema::Schema::new(vec![
947 arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
948 arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
949 ]);
950
951 let batch = RecordBatch::try_new(
952 Arc::new(schema),
953 vec![
954 Arc::new(token_col) as ArrayRef,
955 Arc::new(token_id_col) as ArrayRef,
956 ],
957 )?;
958 Ok(batch)
959 }
960
961 fn into_fst_batch(mut self) -> Result<RecordBatch> {
962 let fst_map = match std::mem::take(&mut self.tokens) {
963 TokenMap::Fst(map) => map,
964 TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
965 };
966 let bytes = fst_map.into_fst().into_inner();
967
968 let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
969 fst_builder.append_value(bytes);
970 let fst_col = fst_builder.finish();
971
972 let mut next_id_builder = UInt32Builder::with_capacity(1);
973 next_id_builder.append_value(self.next_id);
974 let next_id_col = next_id_builder.finish();
975
976 let mut total_length_builder = UInt64Builder::with_capacity(1);
977 total_length_builder.append_value(self.total_length as u64);
978 let total_length_col = total_length_builder.finish();
979
980 let schema = arrow_schema::Schema::new(vec![
981 arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
982 arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
983 arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
984 ]);
985
986 let batch = RecordBatch::try_new(
987 Arc::new(schema),
988 vec![
989 Arc::new(fst_col) as ArrayRef,
990 Arc::new(next_id_col) as ArrayRef,
991 Arc::new(total_length_col) as ArrayRef,
992 ],
993 )?;
994 Ok(batch)
995 }
996
997 fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
998 let mut entries: Vec<_> = map.into_iter().collect();
999 entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1000 let mut builder = fst::MapBuilder::memory();
1001 for (token, token_id) in entries {
1002 builder
1003 .insert(&token, token_id as u64)
1004 .map_err(|e| Error::Index {
1005 message: format!("failed to insert token {}: {}", token, e),
1006 location: location!(),
1007 })?;
1008 }
1009 Ok(builder.into_map())
1010 }
1011
1012 pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
1013 match format {
1014 TokenSetFormat::Arrow => Self::load_arrow(reader).await,
1015 TokenSetFormat::Fst => Self::load_fst(reader).await,
1016 }
1017 }
1018
1019 async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1020 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1021
1022 let (tokens, next_id, total_length) = spawn_blocking(move || {
1023 let mut next_id = 0;
1024 let mut total_length = 0;
1025 let mut tokens = fst::MapBuilder::memory();
1026
1027 let token_col = batch[TOKEN_COL].as_string::<i32>();
1028 let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1029
1030 for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1031 let token = token.ok_or(Error::Index {
1032 message: "found null token in token set".to_owned(),
1033 location: location!(),
1034 })?;
1035 next_id = next_id.max(token_id + 1);
1036 total_length += token.len();
1037 tokens
1038 .insert(token, token_id as u64)
1039 .map_err(|e| Error::Index {
1040 message: format!("failed to insert token {}: {}", token, e),
1041 location: location!(),
1042 })?;
1043 }
1044
1045 Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1046 })
1047 .await
1048 .map_err(|err| Error::Execution {
1049 message: format!("failed to spawn blocking task: {}", err),
1050 location: location!(),
1051 })??;
1052
1053 Ok(Self {
1054 tokens: TokenMap::Fst(tokens),
1055 next_id,
1056 total_length,
1057 })
1058 }
1059
1060 async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1061 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1062 if batch.num_rows() == 0 {
1063 return Err(Error::Index {
1064 message: "token set batch is empty".to_owned(),
1065 location: location!(),
1066 });
1067 }
1068
1069 let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1070 let bytes = fst_col.value(0);
1071 let map = fst::Map::new(bytes.to_vec()).map_err(|e| Error::Index {
1072 message: format!("failed to load fst tokens: {}", e),
1073 location: location!(),
1074 })?;
1075
1076 let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1077 let total_length_col =
1078 batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1079
1080 let next_id = next_id_col.values().first().copied().ok_or(Error::Index {
1081 message: "token next id column is empty".to_owned(),
1082 location: location!(),
1083 })?;
1084
1085 let total_length = total_length_col
1086 .values()
1087 .first()
1088 .copied()
1089 .ok_or(Error::Index {
1090 message: "token total length column is empty".to_owned(),
1091 location: location!(),
1092 })?;
1093
1094 Ok(Self {
1095 tokens: TokenMap::Fst(map),
1096 next_id,
1097 total_length: usize::try_from(total_length).map_err(|_| Error::Index {
1098 message: format!("token total length {} overflows usize", total_length),
1099 location: location!(),
1100 })?,
1101 })
1102 }
1103
1104 pub fn add(&mut self, token: String) -> u32 {
1105 let next_id = self.next_id();
1106 let len = token.len();
1107 let token_id = match self.tokens {
1108 TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1109 _ => unreachable!("tokens must be HashMap while indexing"),
1110 };
1111
1112 if token_id == next_id {
1114 self.next_id += 1;
1115 self.total_length += len;
1116 }
1117
1118 token_id
1119 }
1120
1121 pub fn get(&self, token: &str) -> Option<u32> {
1122 match self.tokens {
1123 TokenMap::HashMap(ref map) => map.get(token).copied(),
1124 TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
1125 }
1126 }
1127
    /// Compact token ids after the given tokens were removed.
    ///
    /// `removed_token_ids` must be sorted ascending. Removed tokens are
    /// dropped from the map and every surviving id is shifted down by the
    /// number of removed ids smaller than it, keeping ids dense. The set is
    /// left in `HashMap` form afterwards.
    pub fn remap(&mut self, removed_token_ids: &[u32]) {
        if removed_token_ids.is_empty() {
            return;
        }

        // An FST is immutable, so expand it into a HashMap before editing.
        let mut map = match std::mem::take(&mut self.tokens) {
            TokenMap::HashMap(map) => map,
            TokenMap::Fst(map) => {
                let mut new_map = HashMap::with_capacity(map.len());
                let mut stream = map.into_stream();
                while let Some((token, token_id)) = stream.next() {
                    new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
                }

                new_map
            }
        };

        map.retain(
            // `Err(index)` is the count of removed ids below `token_id`,
            // which is exactly the amount to shift the surviving id down by.
            |_, token_id| match removed_token_ids.binary_search(token_id) {
                Ok(_) => false,
                Err(index) => {
                    *token_id -= index as u32;
                    true
                }
            },
        );

        self.tokens = TokenMap::HashMap(map);
    }
1159
    /// The id that will be assigned to the next newly added token.
    pub fn next_id(&self) -> u32 {
        self.next_id
    }
1163}
1164
/// Reader over the posting-list file of an inverted index.
///
/// Supports two layouts: the legacy layout (row-id/frequency rows, with
/// per-token row offsets stored in schema metadata) and the newer layout
/// (one compressed-posting row per token).
pub struct PostingListReader {
    reader: Arc<dyn IndexReader>,

    // Per-token row offsets; `Some` only for the legacy layout.
    offsets: Option<Vec<usize>>,

    // Per-token BM25 score upper bounds, when present in the file.
    max_scores: Option<Vec<f32>>,

    // Per-token posting-list lengths; loaded only for the newer layout.
    lengths: Option<Vec<u32>>,

    // Whether the file carries a position column (needed for phrase queries).
    has_position: bool,

    index_cache: WeakLanceCache,
}
1182
1183impl std::fmt::Debug for PostingListReader {
1184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1185 f.debug_struct("InvertedListReader")
1186 .field("offsets", &self.offsets)
1187 .field("max_scores", &self.max_scores)
1188 .finish()
1189 }
1190}
1191
1192impl DeepSizeOf for PostingListReader {
1193 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
1194 self.offsets.deep_size_of_children(context)
1195 + self.max_scores.deep_size_of_children(context)
1196 + self.lengths.deep_size_of_children(context)
1197 }
1198}
1199
1200impl PostingListReader {
    /// Open a posting-list reader over an index file.
    ///
    /// Detects the file layout: files without `POSTING_COL` are the legacy
    /// layout whose per-token offsets (and optional max scores) live in the
    /// schema metadata, while newer files store one row per token with
    /// max-score and length columns that are eagerly loaded here.
    pub(crate) async fn try_new(
        reader: Arc<dyn IndexReader>,
        index_cache: &LanceCache,
    ) -> Result<Self> {
        let has_position = reader.schema().field(POSITION_COL).is_some();
        let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
            // Legacy layout: metadata-driven offsets, no length column.
            let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
            (Some(offsets), max_scores, None)
        } else {
            let metadata = reader
                .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
                .await?;
            let max_scores = metadata[MAX_SCORE_COL]
                .as_primitive::<Float32Type>()
                .values()
                .to_vec();
            let lengths = metadata[LENGTH_COL]
                .as_primitive::<UInt32Type>()
                .values()
                .to_vec();
            (None, Some(max_scores), Some(lengths))
        };

        Ok(Self {
            reader,
            offsets,
            max_scores,
            lengths,
            has_position,
            index_cache: WeakLanceCache::from(index_cache),
        })
    }
1233
    /// Parse legacy per-token offsets (and optional max scores) from the
    /// JSON-encoded schema metadata of an old-format index file.
    fn load_metadata(
        schema: &lance_core::datatypes::Schema,
    ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
        let offsets = schema.metadata.get("offsets").ok_or(Error::Index {
            message: "offsets not found in metadata".to_owned(),
            location: location!(),
        })?;
        let offsets = serde_json::from_str(offsets)?;

        // Very old files may not record max scores at all.
        let max_scores = schema
            .metadata
            .get("max_scores")
            .map(|max_scores| serde_json::from_str(max_scores))
            .transpose()?;
        Ok((offsets, max_scores))
    }
1252
1253 pub fn len(&self) -> usize {
1255 match self.offsets {
1256 Some(ref offsets) => offsets.len(),
1257 None => self.reader.num_rows(),
1258 }
1259 }
1260
    /// True when the index contains no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Whether position data is available (required for phrase queries).
    pub(crate) fn has_positions(&self) -> bool {
        self.has_position
    }
1268
1269 pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1270 let token_id = token_id as usize;
1271
1272 match self.offsets {
1273 Some(ref offsets) => {
1274 let next_offset = offsets
1275 .get(token_id + 1)
1276 .copied()
1277 .unwrap_or(self.reader.num_rows());
1278 next_offset - offsets[token_id]
1279 }
1280 None => {
1281 if let Some(lengths) = &self.lengths {
1282 lengths[token_id] as usize
1283 } else {
1284 panic!("posting list reader is not initialized")
1285 }
1286 }
1287 }
1288 }
1289
    /// Read the raw record batch holding the posting list of `token_id`.
    ///
    /// In the newer layout each token occupies exactly one row; the
    /// position column is fetched only when `with_position` is set.
    pub(crate) async fn posting_batch(
        &self,
        token_id: u32,
        with_position: bool,
    ) -> Result<RecordBatch> {
        if self.offsets.is_some() {
            self.posting_batch_legacy(token_id, with_position).await
        } else {
            let token_id = token_id as usize;
            let columns = if with_position {
                vec![POSTING_COL, POSITION_COL]
            } else {
                vec![POSTING_COL]
            };
            let batch = self
                .reader
                .read_range(token_id..token_id + 1, Some(&columns))
                .await?;
            Ok(batch)
        }
    }
1311
    /// Legacy-layout read: fetch `posting_len` rows starting at the token's
    /// stored offset (row ids, frequencies and, optionally, positions).
    async fn posting_batch_legacy(
        &self,
        token_id: u32,
        with_position: bool,
    ) -> Result<RecordBatch> {
        let mut columns = vec![ROW_ID, FREQUENCY_COL];
        if with_position {
            columns.push(POSITION_COL);
        }

        let length = self.posting_len(token_id);
        let token_id = token_id as usize;
        // `offsets` is guaranteed present on the legacy path (checked by caller).
        let offset = self.offsets.as_ref().unwrap()[token_id];
        let batch = self
            .reader
            .read_range(offset..offset + length, Some(&columns))
            .await?;
        Ok(batch)
    }
1331
    /// Fetch the posting list for `token_id`, served from the index cache
    /// when possible.
    ///
    /// Postings are cached without positions; for phrase queries the
    /// position column is loaded (and cached separately) and attached to a
    /// clone of the cached list.
    #[instrument(level = "debug", skip(self, metrics))]
    pub(crate) async fn posting_list(
        &self,
        token_id: u32,
        is_phrase_query: bool,
        metrics: &dyn MetricsCollector,
    ) -> Result<PostingList> {
        let cache_key = PostingListKey { token_id };
        let mut posting = self
            .index_cache
            .get_or_insert_with_key(cache_key, || async move {
                // Cache miss: record the IO and decode from the file.
                metrics.record_part_load();
                info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
                let batch = self.posting_batch(token_id, false).await?;
                self.posting_list_from_batch(&batch, token_id)
            })
            .await
            .map_err(|e| Error::io(e.to_string(), location!()))?
            .as_ref()
            .clone();

        if is_phrase_query {
            let positions = self.read_positions(token_id).await?;
            posting.set_positions(positions);
        }

        Ok(posting)
    }
1361
1362 pub(crate) fn posting_list_from_batch(
1363 &self,
1364 batch: &RecordBatch,
1365 token_id: u32,
1366 ) -> Result<PostingList> {
1367 let posting_list = PostingList::from_batch(
1368 batch,
1369 self.max_scores
1370 .as_ref()
1371 .map(|max_scores| max_scores[token_id as usize]),
1372 self.lengths
1373 .as_ref()
1374 .map(|lengths| lengths[token_id as usize]),
1375 )?;
1376 Ok(posting_list)
1377 }
1378
    /// Populate the posting-list cache for every token in one bulk read.
    ///
    /// Fails if the (weakly held) cache has been dropped, since prewarming
    /// would otherwise be silently useless.
    async fn prewarm(&self) -> Result<()> {
        let batch = self.read_batch(false).await?;
        for token_id in 0..self.len() {
            let posting_range = self.posting_list_range(token_id as u32);
            let batch = batch.slice(posting_range.start, posting_range.end - posting_range.start);
            // Slices share the full backing buffers; shrink so each cached
            // entry holds only its own posting data.
            let batch = batch.shrink_to_fit()?;
            let posting_list = self.posting_list_from_batch(&batch, token_id as u32)?;
            let inserted = self
                .index_cache
                .insert_with_key(
                    &PostingListKey {
                        token_id: token_id as u32,
                    },
                    Arc::new(posting_list),
                )
                .await;

            if !inserted {
                return Err(Error::Internal {
                    message: "Failed to prewarm index: cache is no longer available".to_string(),
                    location: location!(),
                });
            }
        }

        Ok(())
    }
1408
1409 pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1410 let columns = self.posting_columns(with_position);
1411 let batch = self
1412 .reader
1413 .read_range(0..self.reader.num_rows(), Some(&columns))
1414 .await?;
1415 Ok(batch)
1416 }
1417
1418 pub(crate) async fn read_all(
1419 &self,
1420 with_position: bool,
1421 ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1422 let batch = self.read_batch(with_position).await?;
1423 Ok((0..self.len()).map(move |i| {
1424 let token_id = i as u32;
1425 let range = self.posting_list_range(token_id);
1426 let batch = batch.slice(i, range.end - range.start);
1427 self.posting_list_from_batch(&batch, token_id)
1428 }))
1429 }
1430
    /// Load (and cache) the position column for `token_id`'s posting list.
    ///
    /// A missing-column schema error is mapped to an invalid-input error so
    /// phrase queries against an index built without positions fail with a
    /// clear message.
    async fn read_positions(&self, token_id: u32) -> Result<ListArray> {
        let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
            let batch = self
                .reader
                .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
                .await.map_err(|e| {
                    match e {
                        Error::Schema { .. } => Error::invalid_input(
                            "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
                            location!(),
                        ),
                        e => e
                    }
                })?;
            Result::Ok(Positions(batch[POSITION_COL]
                .as_list::<i32>()
                .clone()))
        }).await?;
        Ok(positions.0.clone())
    }
1451
1452 fn posting_list_range(&self, token_id: u32) -> Range<usize> {
1453 match self.offsets {
1454 Some(ref offsets) => {
1455 let offset = offsets[token_id as usize];
1456 let posting_len = self.posting_len(token_id);
1457 offset..offset + posting_len
1458 }
1459 None => {
1460 let token_id = token_id as usize;
1461 token_id..token_id + 1
1462 }
1463 }
1464 }
1465
1466 fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
1467 let mut base_columns = match self.offsets {
1468 Some(_) => vec![ROW_ID, FREQUENCY_COL],
1469 None => vec![POSTING_COL],
1470 };
1471 if with_position {
1472 base_columns.push(POSITION_COL);
1473 }
1474 base_columns
1475 }
1476}
1477
/// Newtype wrapper so a position `ListArray` can live in the index cache
/// with a `DeepSizeOf` implementation.
#[derive(Clone)]
pub struct Positions(ListArray);
1482
impl DeepSizeOf for Positions {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // Approximate deep size with the Arrow buffer footprint.
        self.0.get_buffer_memory_size()
    }
}
1488
/// Cache key for a token's posting list (stored without positions).
#[derive(Debug, Clone)]
pub struct PostingListKey {
    pub token_id: u32,
}
1494
1495impl CacheKey for PostingListKey {
1496 type ValueType = PostingList;
1497
1498 fn key(&self) -> std::borrow::Cow<'_, str> {
1499 format!("postings-{}", self.token_id).into()
1500 }
1501}
1502
/// Cache key for a token's position column (cached separately from postings).
#[derive(Debug, Clone)]
pub struct PositionKey {
    pub token_id: u32,
}
1507
1508impl CacheKey for PositionKey {
1509 type ValueType = Positions;
1510
1511 fn key(&self) -> std::borrow::Cow<'_, str> {
1512 format!("positions-{}", self.token_id).into()
1513 }
1514}
1515
/// In-memory posting list: either the legacy plain form (row ids and
/// frequencies) or the block-compressed form.
#[derive(Debug, Clone, DeepSizeOf)]
pub enum PostingList {
    Plain(PlainPostingList),
    Compressed(CompressedPostingList),
}
1521
1522impl PostingList {
    /// Build a posting list from a decoded batch; the presence of
    /// `POSTING_COL` selects the compressed representation.
    pub fn from_batch(
        batch: &RecordBatch,
        max_score: Option<f32>,
        length: Option<u32>,
    ) -> Result<Self> {
        match batch.column_by_name(POSTING_COL) {
            Some(_) => {
                // Compressed lists always come with max score and length.
                debug_assert!(max_score.is_some() && length.is_some());
                let posting =
                    CompressedPostingList::from_batch(batch, max_score.unwrap(), length.unwrap());
                Ok(Self::Compressed(posting))
            }
            None => {
                let posting = PlainPostingList::from_batch(batch, max_score);
                Ok(Self::Plain(posting))
            }
        }
    }
1541
    /// Iterate the list's (doc, frequency, positions) entries.
    pub fn iter(&self) -> PostingListIterator<'_> {
        PostingListIterator::new(self)
    }

    /// Whether position data is attached.
    pub fn has_position(&self) -> bool {
        match self {
            Self::Plain(posting) => posting.positions.is_some(),
            Self::Compressed(posting) => posting.positions.is_some(),
        }
    }
1552
    /// Attach a freshly read position column.
    ///
    /// For compressed lists the on-disk column nests the per-document lists
    /// inside a single row, so unwrap that outer level first.
    pub fn set_positions(&mut self, positions: ListArray) {
        match self {
            Self::Plain(posting) => posting.positions = Some(positions),
            Self::Compressed(posting) => {
                posting.positions = Some(positions.value(0).as_list::<i32>().clone());
            }
        }
    }
1561
    /// Upper-bound BM25 score for this token, when known.
    pub fn max_score(&self) -> Option<f32> {
        match self {
            Self::Plain(posting) => posting.max_score,
            Self::Compressed(posting) => Some(posting.max_score),
        }
    }

    /// Number of documents in the list.
    pub fn len(&self) -> usize {
        match self {
            Self::Plain(posting) => posting.len(),
            Self::Compressed(posting) => posting.length as usize,
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
1579
    /// Convert back into a mutable builder.
    ///
    /// Plain lists are keyed by row id, so each row id is translated to its
    /// doc id through `docs` and entries are re-added in ascending doc-id
    /// order. Compressed lists are already doc-id ordered.
    pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
        let mut builder = PostingListBuilder::new(self.has_position());
        match self {
            Self::Plain(posting) => {
                // Temporary pairing of a doc id with its recorded positions.
                struct Item {
                    doc_id: u32,
                    positions: PositionRecorder,
                }
                // Reverse map: row id -> doc id (position within `docs`).
                let doc_ids = docs
                    .row_ids
                    .iter()
                    .enumerate()
                    .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
                    .collect::<HashMap<_, _>>();
                let mut items = Vec::with_capacity(posting.len());
                for (row_id, freq, positions) in posting.iter() {
                    let freq = freq as u32;
                    // Without positions, only the count is kept.
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>())
                        }
                        None => PositionRecorder::Count(freq),
                    };
                    items.push(Item {
                        doc_id: doc_ids[&row_id],
                        positions,
                    });
                }
                // The builder expects entries in doc-id order.
                items.sort_unstable_by_key(|item| item.doc_id);
                for item in items {
                    builder.add(item.doc_id, item.positions);
                }
            }
            Self::Compressed(posting) => {
                posting.iter().for_each(|(doc_id, freq, positions)| {
                    let positions = match positions {
                        Some(positions) => {
                            PositionRecorder::Position(positions.collect::<Vec<_>>())
                        }
                        None => PositionRecorder::Count(freq),
                    };
                    builder.add(doc_id, positions);
                });
            }
        }
        builder
    }
1631}
1632
/// Legacy posting list: parallel row-id / frequency buffers plus an
/// optional list of Int32 positions per document.
#[derive(Debug, PartialEq, Clone)]
pub struct PlainPostingList {
    pub row_ids: ScalarBuffer<u64>,
    pub frequencies: ScalarBuffer<f32>,
    pub max_score: Option<f32>,
    pub positions: Option<ListArray>, // one Int32 list per document
}
1640
1641impl DeepSizeOf for PlainPostingList {
1642 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
1643 self.row_ids.len() * std::mem::size_of::<u64>()
1644 + self.frequencies.len() * std::mem::size_of::<u32>()
1645 + self
1646 .positions
1647 .as_ref()
1648 .map(|positions| positions.get_buffer_memory_size())
1649 .unwrap_or(0)
1650 }
1651}
1652
1653impl PlainPostingList {
    /// Assemble a plain posting list from its parts.
    pub fn new(
        row_ids: ScalarBuffer<u64>,
        frequencies: ScalarBuffer<f32>,
        max_score: Option<f32>,
        positions: Option<ListArray>,
    ) -> Self {
        Self {
            row_ids,
            frequencies,
            max_score,
            positions,
        }
    }

    /// Decode from a legacy-layout batch (row-id and frequency columns,
    /// optional positions).
    pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
        let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
        let frequencies = batch[FREQUENCY_COL]
            .as_primitive::<Float32Type>()
            .values()
            .clone();
        let positions = batch
            .column_by_name(POSITION_COL)
            .map(|col| col.as_list::<i32>().clone());

        Self::new(row_ids, frequencies, max_score, positions)
    }
1680
    /// Number of documents in the list.
    pub fn len(&self) -> usize {
        self.row_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
1688
    /// Iterate `(row_id, frequency, positions)` per document.
    ///
    /// Positions are sliced straight out of the child Int32 buffer via the
    /// list's value offsets, avoiding a per-row Arrow array allocation.
    pub fn iter(&self) -> PlainPostingListIterator<'_> {
        Box::new(
            self.row_ids
                .iter()
                .zip(self.frequencies.iter())
                .enumerate()
                .map(|(idx, (doc_id, freq))| {
                    (
                        *doc_id,
                        *freq,
                        self.positions.as_ref().map(|p| {
                            let start = p.value_offsets()[idx] as usize;
                            let end = p.value_offsets()[idx + 1] as usize;
                            Box::new(
                                p.values().as_primitive::<Int32Type>().values()[start..end]
                                    .iter()
                                    .map(|pos| *pos as u32),
                            ) as _
                        }),
                    )
                }),
        )
    }
1712
    /// The `i`-th document as a row-id-keyed info record.
    #[inline]
    pub fn doc(&self, i: usize) -> LocatedDocInfo {
        LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
    }

    /// Positions of the `index`-th document, as an Arrow array.
    pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
        self.positions
            .as_ref()
            .map(|positions| positions.value(index))
    }

    pub fn max_score(&self) -> Option<f32> {
        self.max_score
    }

    pub fn row_id(&self, i: usize) -> u64 {
        self.row_ids[i]
    }
1731}
1732
/// Block-compressed posting list; each binary block is prefixed with its
/// max score (bytes 0..4) and least doc id (bytes 4..8), little-endian.
#[derive(Debug, PartialEq, Clone)]
pub struct CompressedPostingList {
    pub max_score: f32,
    pub length: u32,
    pub blocks: LargeBinaryArray,
    pub positions: Option<ListArray>,
}
1742
impl DeepSizeOf for CompressedPostingList {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        // Arrow buffer footprints stand in for the true deep size.
        self.blocks.get_buffer_memory_size()
            + self
                .positions
                .as_ref()
                .map(|positions| positions.get_buffer_memory_size())
                .unwrap_or(0)
    }
}
1753
1754impl CompressedPostingList {
    /// Assemble a compressed posting list from its parts.
    pub fn new(
        blocks: LargeBinaryArray,
        max_score: f32,
        length: u32,
        positions: Option<ListArray>,
    ) -> Self {
        Self {
            max_score,
            length,
            blocks,
            positions,
        }
    }
1768
    /// Decode from a single-row batch of the newer layout, unwrapping the
    /// outer one-element list around the block data (and positions).
    pub fn from_batch(batch: &RecordBatch, max_score: f32, length: u32) -> Self {
        debug_assert_eq!(batch.num_rows(), 1);
        let blocks = batch[POSTING_COL]
            .as_list::<i32>()
            .value(0)
            .as_binary::<i64>()
            .clone();
        let positions = batch
            .column_by_name(POSITION_COL)
            .map(|col| col.as_list::<i32>().value(0).as_list::<i32>().clone());

        Self {
            max_score,
            length,
            blocks,
            positions,
        }
    }
1787
    /// Iterate the list by decompressing blocks on the fly.
    pub fn iter(&self) -> CompressedPostingListIterator {
        CompressedPostingListIterator::new(
            self.length as usize,
            self.blocks.clone(),
            self.positions.clone(),
        )
    }
1795
1796 pub fn block_max_score(&self, block_idx: usize) -> f32 {
1797 let block = self.blocks.value(block_idx);
1798 block[0..4].try_into().map(f32::from_le_bytes).unwrap()
1799 }
1800
1801 pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
1802 let block = self.blocks.value(block_idx);
1803 block[4..8].try_into().map(u32::from_le_bytes).unwrap()
1804 }
1805}
1806
/// Mutable accumulator for a single token's posting list, kept in
/// ascending doc-id order by its callers.
#[derive(Debug)]
pub struct PostingListBuilder {
    pub doc_ids: ExpLinkedList<u32>,
    pub frequencies: ExpLinkedList<u32>,
    pub positions: Option<PositionBuilder>,
}
1813
1814impl PostingListBuilder {
    /// Approximate in-memory size in bytes.
    pub fn size(&self) -> u64 {
        (std::mem::size_of::<u32>() * self.doc_ids.len()
            + std::mem::size_of::<u32>() * self.frequencies.len()
            + self
                .positions
                .as_ref()
                .map(|positions| positions.size())
                .unwrap_or(0)) as u64
    }

    /// Whether term positions are being recorded.
    pub fn has_positions(&self) -> bool {
        self.positions.is_some()
    }

    pub fn new(with_position: bool) -> Self {
        Self {
            doc_ids: ExpLinkedList::new().with_capacity_limit(128),
            frequencies: ExpLinkedList::new().with_capacity_limit(128),
            positions: with_position.then(PositionBuilder::new),
        }
    }

    /// Number of documents added so far.
    pub fn len(&self) -> usize {
        self.doc_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
1844
    /// Iterate `(doc_id, frequency, positions)` in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&u32, &u32, Option<&[u32]>)> {
        self.doc_ids
            .iter()
            .zip(self.frequencies.iter())
            .enumerate()
            .map(|(idx, (doc_id, freq))| {
                let positions = self.positions.as_ref().map(|positions| positions.get(idx));
                (doc_id, freq, positions)
            })
    }
1855
    /// Record one document; the frequency is the number of recorded
    /// positions (or the stored count when positions are not kept).
    pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
        self.doc_ids.push(doc_id);
        self.frequencies.push(term_positions.len());
        if let Some(positions) = self.positions.as_mut() {
            positions.push(term_positions.into_vec());
        }
    }
1863
    /// Serialize into a single-row batch of the newer layout: compressed
    /// posting blocks, overall max score, list length and (optionally)
    /// compressed per-document positions.
    pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
        let length = self.len();
        // Overall upper bound is the max over all block maxima.
        let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);

        let schema = inverted_list_schema(self.has_positions());
        let compressed = compress_posting_list(
            self.doc_ids.len(),
            self.doc_ids.iter(),
            self.frequencies.iter(),
            block_max_scores.into_iter(),
        )?;
        // Wrap the block array in a one-element list column.
        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
        let mut columns = vec![
            Arc::new(ListArray::try_new(
                Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
                offsets,
                Arc::new(compressed),
                None,
            )?) as ArrayRef,
            Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
            Arc::new(UInt32Array::from_iter_values(std::iter::once(
                self.len() as u32
            ))) as ArrayRef,
        ];

        if let Some(positions) = self.positions.as_ref() {
            // One compressed byte blob per document, nested as
            // List<List<LargeBinary>> with a single outer row.
            let mut position_builder = ListBuilder::new(ListBuilder::with_capacity(
                LargeBinaryBuilder::new(),
                length,
            ));
            for index in 0..length {
                let positions_in_doc = positions.get(index);
                let compressed = compress_positions(positions_in_doc)?;
                let inner_builder = position_builder.values();
                inner_builder.append_value(compressed.into_iter());
            }
            position_builder.append(true);
            let position_col = position_builder.finish();
            columns.push(Arc::new(position_col));
        }

        let batch = RecordBatch::try_new(schema, columns)?;
        Ok(batch)
    }
1909
    /// Drop removed doc ids and shift the survivors down.
    ///
    /// `removed` must be sorted ascending and doc ids are iterated in
    /// ascending order, so `cursor` always equals the number of removed ids
    /// below the current doc id — exactly the shift to apply.
    pub fn remap(&mut self, removed: &[u32]) {
        let mut cursor = 0;
        let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
        let mut new_frequencies = ExpLinkedList::with_capacity(self.len());
        let mut new_positions = self.positions.as_mut().map(|_| PositionBuilder::new());
        for (&doc_id, &freq, positions) in self.iter() {
            while cursor < removed.len() && removed[cursor] < doc_id {
                cursor += 1;
            }
            if cursor < removed.len() && removed[cursor] == doc_id {
                // This document itself was removed: skip it entirely.
                continue;
            }
            new_doc_ids.push(doc_id - cursor as u32);
            new_frequencies.push(freq);
            if let Some(new_positions) = new_positions.as_mut() {
                new_positions.push(positions.unwrap().to_vec());
            }
        }

        self.doc_ids = new_doc_ids;
        self.frequencies = new_frequencies;
        self.positions = new_positions;
    }
1936}
1937
/// Flattened per-token positions: one `positions` vector for all docs,
/// with `offsets` delimiting each document's slice (offsets[0] == 0).
#[derive(Debug, Clone, DeepSizeOf)]
pub struct PositionBuilder {
    positions: Vec<u32>,
    offsets: Vec<i32>,
}
1943
impl Default for PositionBuilder {
    fn default() -> Self {
        Self::new()
    }
}
1949
1950impl PositionBuilder {
1951 pub fn new() -> Self {
1952 Self {
1953 positions: Vec::new(),
1954 offsets: vec![0],
1955 }
1956 }
1957
1958 pub fn size(&self) -> usize {
1959 std::mem::size_of::<u32>() * self.positions.len()
1960 + std::mem::size_of::<i32>() * self.offsets.len()
1961 }
1962
1963 pub fn total_len(&self) -> usize {
1964 self.positions.len()
1965 }
1966
1967 pub fn push(&mut self, positions: Vec<u32>) {
1968 self.positions.extend(positions);
1969 self.offsets.push(self.positions.len() as i32);
1970 }
1971
1972 pub fn get(&self, i: usize) -> &[u32] {
1973 let start = self.offsets[i] as usize;
1974 let end = self.offsets[i + 1] as usize;
1975 &self.positions[start..end]
1976 }
1977}
1978
1979impl From<Vec<Vec<u32>>> for PositionBuilder {
1980 fn from(positions: Vec<Vec<u32>>) -> Self {
1981 let mut builder = Self::new();
1982 builder.offsets.reserve(positions.len());
1983 for pos in positions {
1984 builder.push(pos);
1985 }
1986 builder
1987 }
1988}
1989
/// A posting entry keyed either by row id (legacy, `Located`) or by a
/// compact doc id (`Raw`).
#[derive(Debug, Clone, DeepSizeOf, Copy)]
pub enum DocInfo {
    Located(LocatedDocInfo),
    Raw(RawDocInfo),
}
1995
1996impl DocInfo {
1997 pub fn doc_id(&self) -> u64 {
1998 match self {
1999 Self::Raw(info) => info.doc_id as u64,
2000 Self::Located(info) => info.row_id,
2001 }
2002 }
2003
2004 pub fn frequency(&self) -> u32 {
2005 match self {
2006 Self::Raw(info) => info.frequency,
2007 Self::Located(info) => info.frequency as u32,
2008 }
2009 }
2010}
2011
// Equality and ordering consider only the document key, so heaps and
// sorts over `DocInfo` are keyed by doc id alone.
impl Eq for DocInfo {}

impl PartialEq for DocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.doc_id() == other.doc_id()
    }
}

impl PartialOrd for DocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.doc_id().cmp(&other.doc_id())
    }
}
2031
/// Posting entry keyed by absolute row id (legacy plain lists).
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct LocatedDocInfo {
    pub row_id: u64,
    pub frequency: f32,
}

impl LocatedDocInfo {
    pub fn new(row_id: u64, frequency: f32) -> Self {
        Self { row_id, frequency }
    }
}
2043
// Equality and ordering by row id only; frequency is ignored.
impl Eq for LocatedDocInfo {}

impl PartialEq for LocatedDocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.row_id == other.row_id
    }
}

impl PartialOrd for LocatedDocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for LocatedDocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.row_id.cmp(&other.row_id)
    }
}
2063
/// Posting entry keyed by compact doc id (newer compressed lists).
#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
pub struct RawDocInfo {
    pub doc_id: u32,
    pub frequency: u32,
}

impl RawDocInfo {
    pub fn new(doc_id: u32, frequency: u32) -> Self {
        Self { doc_id, frequency }
    }
}
2075
// Equality and ordering by doc id only; frequency is ignored.
impl Eq for RawDocInfo {}

impl PartialEq for RawDocInfo {
    fn eq(&self, other: &Self) -> bool {
        self.doc_id == other.doc_id
    }
}

impl PartialOrd for RawDocInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for RawDocInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.doc_id.cmp(&other.doc_id)
    }
}
2095
/// Per-document row-id / token-count table; a doc id is a position within
/// `row_ids`.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct DocSet {
    row_ids: Vec<u64>,
    num_tokens: Vec<u32>,
    // (row_id, doc_id) pairs sorted by row id, used for reverse lookups;
    // empty for legacy sets whose `row_ids` are themselves sorted.
    inv: Vec<(u64, u32)>,

    total_tokens: u64,
}
2107
2108impl DocSet {
    /// Number of documents.
    #[inline]
    pub fn len(&self) -> usize {
        self.row_ids.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterate `(row_id, num_tokens)` pairs in doc-id order.
    pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
        self.row_ids.iter().zip(self.num_tokens.iter())
    }

    /// Row id of the document with id `doc_id`.
    pub fn row_id(&self, doc_id: u32) -> u64 {
        self.row_ids[doc_id as usize]
    }
2125
    /// Reverse lookup: the doc id for `row_id`, or `None` if absent.
    ///
    /// With an empty `inv` (legacy sets) the row id itself is returned on a
    /// hit — NOTE(review): this relies on legacy doc ids being the row ids
    /// themselves (plain lists are row-id keyed); confirm against callers.
    pub fn doc_id(&self, row_id: u64) -> Option<u64> {
        if self.inv.is_empty() {
            match self.row_ids.binary_search(&row_id) {
                Ok(_) => Some(row_id),
                Err(_) => None,
            }
        } else {
            match self.inv.binary_search_by_key(&row_id, |x| x.0) {
                Ok(idx) => Some(self.inv[idx].1 as u64),
                Err(_) => None,
            }
        }
    }
    /// Total number of tokens across all documents.
    pub fn total_tokens_num(&self) -> u64 {
        self.total_tokens
    }

    /// Average document length in tokens (NaN for an empty set).
    #[inline]
    pub fn average_length(&self) -> f32 {
        self.total_tokens as f32 / self.len() as f32
    }
2148
    /// Compute per-block BM25 score upper bounds for one token's postings.
    ///
    /// `doc_ids` and `freqs` are parallel iterators. The running max of the
    /// term-frequency component is flushed every `BLOCK_SIZE` entries,
    /// scaled by `idf * (K1 + 1)`.
    /// NOTE(review): `length` is taken from `size_hint().0` and used as the
    /// document frequency for `idf` — assumes the iterators report an
    /// exact lower bound.
    pub fn calculate_block_max_scores<'a>(
        &self,
        doc_ids: impl Iterator<Item = &'a u32>,
        freqs: impl Iterator<Item = &'a u32>,
    ) -> Vec<f32> {
        let avgdl = self.average_length();
        let length = doc_ids.size_hint().0;
        let mut block_max_scores = Vec::with_capacity(length);
        let mut max_score = f32::MIN;
        for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
            // BM25 length-normalization term for this document.
            let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
            let freq = *freq as f32;
            let score = freq / (freq + doc_norm);
            if score > max_score {
                max_score = score;
            }
            if (i + 1) % BLOCK_SIZE == 0 {
                max_score *= idf(length, self.len()) * (K1 + 1.0);
                block_max_scores.push(max_score);
                max_score = f32::MIN;
            }
        }
        // Flush the trailing partial block, if any.
        if length % BLOCK_SIZE > 0 {
            max_score *= idf(length, self.len()) * (K1 + 1.0);
            block_max_scores.push(max_score);
        }
        block_max_scores
    }
2177
2178 pub fn to_batch(&self) -> Result<RecordBatch> {
2179 let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
2180 let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
2181
2182 let schema = arrow_schema::Schema::new(vec![
2183 arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
2184 arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
2185 ]);
2186
2187 let batch = RecordBatch::try_new(
2188 Arc::new(schema),
2189 vec![
2190 Arc::new(row_id_col) as ArrayRef,
2191 Arc::new(num_tokens_col) as ArrayRef,
2192 ],
2193 )?;
2194 Ok(batch)
2195 }
2196
2197 pub async fn load(
2198 reader: Arc<dyn IndexReader>,
2199 is_legacy: bool,
2200 frag_reuse_index: Option<Arc<FragReuseIndex>>,
2201 ) -> Result<Self> {
2202 let batch = reader.read_range(0..reader.num_rows(), None).await?;
2203 let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
2204 let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
2205
2206 if is_legacy {
2208 let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
2209 .values()
2210 .iter()
2211 .filter_map(|id| {
2212 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2213 frag_reuse_index_ref.remap_row_id(*id)
2214 } else {
2215 Some(*id)
2216 }
2217 })
2218 .zip(num_tokens_col.values().iter())
2219 .sorted_unstable_by_key(|x| x.0)
2220 .unzip();
2221
2222 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2223 return Ok(Self {
2224 row_ids,
2225 num_tokens,
2226 inv: Vec::new(),
2227 total_tokens,
2228 });
2229 }
2230
2231 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
2234 let mut row_ids = Vec::with_capacity(row_id_col.len());
2235 let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
2236 for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
2237 if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
2238 row_ids.push(new_row_id);
2239 num_tokens.push(*num_token);
2240 }
2241 }
2242
2243 let mut inv: Vec<(u64, u32)> = row_ids
2244 .iter()
2245 .enumerate()
2246 .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2247 .collect();
2248 inv.sort_unstable_by_key(|entry| entry.0);
2249
2250 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2251 return Ok(Self {
2252 row_ids,
2253 num_tokens,
2254 inv,
2255 total_tokens,
2256 });
2257 }
2258
2259 let row_ids = row_id_col.values().to_vec();
2260 let num_tokens = num_tokens_col.values().to_vec();
2261 let mut inv: Vec<(u64, u32)> = row_ids
2262 .iter()
2263 .enumerate()
2264 .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2265 .collect();
2266 if !row_ids.is_sorted() {
2267 inv.sort_unstable_by_key(|entry| entry.0);
2268 }
2269 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
2270 Ok(Self {
2271 row_ids,
2272 num_tokens,
2273 inv,
2274 total_tokens,
2275 })
2276 }
2277
    /// Apply a row-id mapping in place and return the (old) doc ids of
    /// documents mapped to `None` (i.e. deleted).
    ///
    /// Rows absent from `mapping` keep their current row id.
    /// NOTE(review): `total_tokens` and `inv` are not updated here even
    /// when documents are dropped — confirm callers rebuild these (e.g.
    /// via a subsequent reload) before relying on them.
    pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
        let mut removed = Vec::new();
        let len = self.len();
        let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
        let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
        for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
            match mapping.get(&row_id) {
                Some(Some(new_row_id)) => {
                    self.row_ids.push(*new_row_id);
                    self.num_tokens.push(num_token);
                }
                Some(None) => {
                    removed.push(doc_id as u32);
                }
                None => {
                    self.row_ids.push(row_id);
                    self.num_tokens.push(num_token);
                }
            }
        }
        removed
    }
2302
    /// Token count of the document with id `doc_id`.
    #[inline]
    pub fn num_tokens(&self, doc_id: u32) -> u32 {
        self.num_tokens[doc_id as usize]
    }

    /// Token count looked up by row id; 0 when the row id is not found.
    /// NOTE(review): binary search assumes `row_ids` is sorted, which
    /// `load` only guarantees for legacy sets — confirm call sites.
    #[inline]
    pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
        self.row_ids
            .binary_search(&row_id)
            .map(|idx| self.num_tokens[idx])
            .unwrap_or(0)
    }
2317
2318 pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
2321 self.row_ids.push(row_id);
2322 self.num_tokens.push(num_tokens);
2323 self.total_tokens += num_tokens as u64;
2324 self.row_ids.len() as u32 - 1
2325 }
2326}
2327
/// Brute-force full-text search over in-memory batches (no index).
///
/// Returns row ids of documents sharing at least one token with `query`.
/// Phrase queries are rejected: they need positional data that only the
/// FTS index provides.
pub fn flat_full_text_search(
    batches: &[&RecordBatch],
    doc_col: &str,
    query: &str,
    tokenizer: Option<Box<dyn LanceTokenizer>>,
) -> Result<Vec<u64>> {
    if batches.is_empty() {
        return Ok(vec![]);
    }

    if is_phrase_query(query) {
        return Err(Error::invalid_input(
            "phrase query is not supported for flat full text search, try using FTS index",
            location!(),
        ));
    }

    // Dispatch on the string offset width of the document column.
    match batches[0][doc_col].data_type() {
        DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
        DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
        data_type => Err(Error::invalid_input(
            format!("unsupported data type {} for inverted index", data_type),
            location!(),
        )),
    }
}
2354
2355fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
2356 batches: &[&RecordBatch],
2357 doc_col: &str,
2358 query: &str,
2359 tokenizer: Option<Box<dyn LanceTokenizer>>,
2360) -> Result<Vec<u64>> {
2361 let mut results = Vec::new();
2362 let mut tokenizer =
2363 tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
2364 let query_tokens = collect_query_tokens(query, &mut tokenizer, None);
2365
2366 for batch in batches {
2367 let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
2368 let doc_array = batch[doc_col].as_string::<Offset>();
2369 for i in 0..row_id_array.len() {
2370 let doc = doc_array.value(i);
2371 let doc_tokens = collect_doc_tokens(doc, &mut tokenizer, Some(&query_tokens));
2372 if !doc_tokens.is_empty() {
2373 results.push(row_id_array.value(i));
2374 assert!(doc.contains(query));
2375 }
2376 }
2377 }
2378
2379 Ok(results)
2380}
2381
2382#[allow(clippy::too_many_arguments)]
2383pub fn flat_bm25_search(
2384 batch: RecordBatch,
2385 doc_col: &str,
2386 query_tokens: &Tokens,
2387 tokenizer: &mut Box<dyn LanceTokenizer>,
2388 scorer: &mut MemBM25Scorer,
2389) -> std::result::Result<RecordBatch, DataFusionError> {
2390 let doc_iter = iter_str_array(&batch[doc_col]);
2391 let mut scores = Vec::with_capacity(batch.num_rows());
2392 for doc in doc_iter {
2393 let Some(doc) = doc else {
2394 scores.push(0.0);
2395 continue;
2396 };
2397
2398 let doc_tokens = collect_doc_tokens(doc, tokenizer, None);
2399 scorer.update(&doc_tokens);
2400 let doc_tokens = doc_tokens
2401 .into_iter()
2402 .filter(|t| query_tokens.contains(t))
2403 .collect::<Vec<_>>();
2404
2405 let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / scorer.avg_doc_length());
2406 let mut doc_token_count = HashMap::new();
2407 for token in doc_tokens {
2408 doc_token_count
2409 .entry(token)
2410 .and_modify(|count| *count += 1)
2411 .or_insert(1);
2412 }
2413 let mut score = 0.0;
2414 for token in query_tokens {
2415 let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32;
2416
2417 let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
2418 score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
2419 }
2420 scores.push(score);
2421 }
2422
2423 let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
2424 let batch = batch
2425 .try_with_column(SCORE_FIELD.clone(), score_col)?
2426 .project_by_schema(&FTS_SCHEMA)?; Ok(batch)
2428}
2429
/// Wrap `input` in a stream that BM25-scores each batch against `query` and
/// drops rows whose score is not positive.
///
/// When `index` is provided, its tokenizer and corpus statistics seed the
/// scorer so results are comparable to indexed search; otherwise a simple
/// whitespace tokenizer and an empty in-memory scorer are used, and
/// statistics accumulate as batches flow through.
pub fn flat_bm25_search_stream(
    input: SendableRecordBatchStream,
    doc_col: String,
    query: String,
    index: &Option<InvertedIndex>,
) -> SendableRecordBatchStream {
    // Reuse the index's tokenizer when available so the flat path tokenizes
    // identically to the indexed path.
    let mut tokenizer = match index {
        Some(index) => index.tokenizer(),
        None => Box::new(TextTokenizer::new(
            tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::SimpleTokenizer::default(),
            )
            .build(),
        )),
    };
    let tokens = collect_query_tokens(&query, &mut tokenizer, None);

    // Seed the in-memory scorer from the index's corpus statistics; with no
    // index (or an empty one) start from zero and learn from the stream.
    let mut bm25_scorer = match index {
        Some(index) => {
            let index_bm25_scorer =
                IndexBM25Scorer::new(index.partitions.iter().map(|p| p.as_ref()));
            if index_bm25_scorer.num_docs() == 0 {
                MemBM25Scorer::new(0, 0, HashMap::new())
            } else {
                // Only the query's tokens need per-token doc counts; clamp to
                // at least 1 so idf stays finite for seeded tokens.
                let mut token_docs = HashMap::with_capacity(tokens.len());
                for token in &tokens {
                    let token_nq = index_bm25_scorer.num_docs_containing_token(token).max(1);
                    token_docs.insert(token.clone(), token_nq);
                }
                // NOTE(review): total token count is reconstructed as
                // avg_doc_length * num_docs, which truncates through integer
                // conversion — confirm the precision loss is acceptable.
                MemBM25Scorer::new(
                    index_bm25_scorer.avg_doc_length() as u64 * index_bm25_scorer.num_docs() as u64,
                    index_bm25_scorer.num_docs(),
                    token_docs,
                )
            }
        }
        None => MemBM25Scorer::new(0, 0, HashMap::new()),
    };

    let stream = input.map(move |batch| {
        let batch = batch?;

        let batch = flat_bm25_search(batch, &doc_col, &tokens, &mut tokenizer, &mut bm25_scorer)?;

        // Keep only rows that actually matched (positive score).
        let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
        let mask = score_col
            .iter()
            .map(|score| score.is_some_and(|score| score > 0.0))
            .collect::<Vec<_>>();
        let mask = BooleanArray::from(mask);
        let batch = arrow::compute::filter_record_batch(&batch, &mask)?;
        debug_assert!(batch[ROW_ID].null_count() == 0, "flat FTS produces nulls");
        Ok(batch)
    });

    Box::pin(RecordBatchStreamAdapter::new(FTS_SCHEMA.clone(), stream)) as SendableRecordBatchStream
}
2488
/// Returns true if `query` is a phrase query, i.e. wrapped in double quotes.
///
/// The string must be at least two characters long so the opening and
/// closing quotes are distinct; a lone `"` would otherwise satisfy both
/// `starts_with` and `ends_with` and be misclassified as a phrase query.
pub fn is_phrase_query(query: &str) -> bool {
    query.len() >= 2 && query.starts_with('\"') && query.ends_with('\"')
}
2492
#[cfg(test)]
mod tests {
    use crate::scalar::inverted::lance_tokenizer::DocType;
    use lance_core::cache::LanceCache;
    use lance_core::utils::tempfile::TempObjDir;
    use lance_io::object_store::ObjectStore;

    use crate::metrics::NoOpMetricsCollector;
    use crate::prefilter::NoFilter;
    use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
    use crate::scalar::inverted::encoding::decompress_posting_list;
    use crate::scalar::inverted::query::{FtsSearchParams, Operator};
    use crate::scalar::lance_format::LanceIndexStore;

    use super::*;

    // Remapping a posting list removes the given doc ids and compacts the
    // remaining ones; after a round-trip through `to_batch` +
    // `decompress_posting_list` the result must match a freshly built list.
    #[tokio::test]
    async fn test_posting_builder_remap() {
        let mut builder = PostingListBuilder::new(false);
        // More docs than one block to exercise the multi-block path.
        let n = BLOCK_SIZE + 3;
        for i in 0..n {
            builder.add(i as u32, PositionRecorder::Count(1));
        }
        let removed = vec![5, 7];
        builder.remap(&removed);

        // A list built from scratch with `removed.len()` fewer docs is the
        // expected post-remap state.
        let mut expected = PostingListBuilder::new(false);
        for i in 0..n - removed.len() {
            expected.add(i as u32, PositionRecorder::Count(1));
        }
        assert_eq!(builder.doc_ids, expected.doc_ids);
        assert_eq!(builder.frequencies, expected.frequencies);

        // Serialize and decode again: doc ids and frequencies must survive.
        let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
        let (doc_ids, freqs) = decompress_posting_list(
            (n - removed.len()) as u32,
            batch[POSTING_COL]
                .as_list::<i32>()
                .value(0)
                .as_binary::<i64>(),
        )
        .unwrap();
        assert!(doc_ids
            .iter()
            .zip(expected.doc_ids.iter())
            .all(|(a, b)| a == b));
        assert!(freqs
            .iter()
            .zip(expected.frequencies.iter())
            .all(|(a, b)| a == b));
    }

    // A remap that deletes every document referencing a token must drop the
    // token and its posting list, eventually leaving an empty index that can
    // still be written out.
    #[tokio::test]
    async fn test_remap_to_empty_posting_list() {
        let tmpdir = TempObjDir::default();
        let store = Arc::new(LanceIndexStore::new(
            ObjectStore::local().into(),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());

        // "lance" -> doc 0; "lake" -> docs 1 and 2.
        builder.tokens.add("lance".to_owned());
        builder.tokens.add("lake".to_owned());
        builder.posting_lists.push(PostingListBuilder::new(false));
        builder.posting_lists.push(PostingListBuilder::new(false));
        builder.posting_lists[0].add(0, PositionRecorder::Count(1));
        builder.posting_lists[1].add(1, PositionRecorder::Count(2));
        builder.posting_lists[1].add(2, PositionRecorder::Count(3));
        builder.docs.append(0, 1);
        builder.docs.append(1, 1);
        builder.docs.append(2, 1);
        builder.write(store.as_ref()).await.unwrap();

        // Reload from storage and continue editing via the builder.
        let index = InvertedPartition::load(
            store.clone(),
            0,
            None,
            &LanceCache::no_cache(),
            TokenSetFormat::default(),
        )
        .await
        .unwrap();
        let mut builder = index.into_builder().await.unwrap();

        // Delete row 0 (the only "lance" doc) and move row 2 to row 3.
        let mapping = HashMap::from([(0, None), (2, Some(3))]);
        builder.remap(&mapping).await.unwrap();

        // Only "lake" survives, with its two (remapped) docs.
        assert_eq!(builder.tokens.len(), 1);
        assert_eq!(builder.tokens.get("lake"), Some(0));
        assert_eq!(builder.posting_lists.len(), 1);
        assert_eq!(builder.posting_lists[0].len(), 2);
        assert_eq!(builder.docs.len(), 2);
        assert_eq!(builder.docs.row_id(0), 1);
        assert_eq!(builder.docs.row_id(1), 3);

        builder.write(store.as_ref()).await.unwrap();

        // Delete the remaining docs: the index must become completely empty.
        let mapping = HashMap::from([(1, None), (3, None)]);
        builder.remap(&mapping).await.unwrap();

        assert_eq!(builder.tokens.len(), 0);
        assert_eq!(builder.posting_lists.len(), 0);
        assert_eq!(builder.docs.len(), 0);

        // Writing an empty index must still succeed.
        builder.write(store.as_ref()).await.unwrap();
    }

    // Two partitions store different posting lists for the same token; the
    // shared cache must key entries by partition so lookups don't collide.
    #[tokio::test]
    async fn test_posting_cache_conflict_across_partitions() {
        let tmpdir = TempObjDir::default();
        let store = Arc::new(LanceIndexStore::new(
            ObjectStore::local().into(),
            tmpdir.clone(),
            Arc::new(LanceCache::no_cache()),
        ));

        // Partition 0: one doc (row 100) containing "test".
        let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
        builder1.tokens.add("test".to_owned());
        builder1.posting_lists.push(PostingListBuilder::new(false));
        builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
        builder1.docs.append(100, 1); builder1.write(store.as_ref()).await.unwrap();

        // Partition 1: four docs (rows 200-203) containing "test".
        let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
        builder2.tokens.add("test".to_owned()); builder2.posting_lists.push(PostingListBuilder::new(false));
        builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
        builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
        builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
        builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
        builder2.docs.append(200, 2); builder2.docs.append(201, 1); builder2.docs.append(202, 3); builder2.docs.append(203, 1); builder2.write(store.as_ref()).await.unwrap();

        // Hand-write the index metadata file listing both partitions.
        let metadata = std::collections::HashMap::from_iter(vec![
            (
                "partitions".to_owned(),
                serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
            ),
            (
                "params".to_owned(),
                serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
            ),
            (
                TOKEN_SET_FORMAT_KEY.to_owned(),
                TokenSetFormat::default().to_string(),
            ),
        ]);
        let mut writer = store
            .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
            .await
            .unwrap();
        writer.finish_with_metadata(metadata).await.unwrap();

        // Non-trivial cache capacity so posting lists actually get cached.
        let cache = Arc::new(LanceCache::with_capacity(4096));
        let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
            .await
            .unwrap();

        assert_eq!(index.partitions.len(), 2);
        assert_eq!(index.partitions[0].tokens.len(), 1);
        assert_eq!(index.partitions[1].tokens.len(), 1);

        // Partition load order is not guaranteed; accept both arrangements.
        if index.partitions[0].id() == 0 {
            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
            assert_eq!(index.partitions[0].docs.len(), 1);
            assert_eq!(index.partitions[1].docs.len(), 4);
        } else {
            assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
            assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
            assert_eq!(index.partitions[0].docs.len(), 4);
            assert_eq!(index.partitions[1].docs.len(), 1);
        }

        // Prewarm loads both partitions' postings into the shared cache,
        // which is where a key collision across partitions would surface.
        index.prewarm().await.unwrap();

        let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
        let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
        let prefilter = Arc::new(NoFilter);
        let metrics = Arc::new(NoOpMetricsCollector);

        let (row_ids, scores) = index
            .bm25_search(tokens, params, Operator::Or, prefilter, metrics)
            .await
            .unwrap();

        // All five docs (1 from partition 0, 4 from partition 1) must match.
        assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
        assert!(!row_ids.is_empty(), "Should find at least some documents");
        assert_eq!(row_ids.len(), scores.len());

        for &score in &scores {
            assert!(score > 0.0, "All scores should be positive");
        }

        assert!(
            row_ids.contains(&100),
            "Should contain row_id from partition 0"
        );
        assert!(
            row_ids.iter().any(|&id| id >= 200),
            "Should contain row_id from partition 1"
        );
    }
}