1use std::fmt::{Debug, Display};
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use std::{
8 cmp::{Reverse, min},
9 collections::BinaryHeap,
10};
11use std::{
12 collections::{HashMap, HashSet},
13 ops::Range,
14 time::Instant,
15};
16
17use crate::metrics::NoOpMetricsCollector;
18use crate::prefilter::NoFilter;
19use crate::scalar::registry::{TrainingCriteria, TrainingOrdering};
20use arrow::array::{FixedSizeListBuilder, Float32Builder};
21use arrow::datatypes::{self, Float32Type, Int32Type, UInt64Type};
22use arrow::{
23 array::{
24 AsArray, LargeBinaryBuilder, ListBuilder, StringBuilder, UInt32Builder, UInt64Builder,
25 },
26 buffer::{Buffer, OffsetBuffer},
27};
28use arrow::{buffer::ScalarBuffer, datatypes::UInt32Type};
29use arrow_array::{
30 Array, ArrayRef, Float32Array, LargeBinaryArray, ListArray, OffsetSizeTrait, RecordBatch,
31 UInt32Array, UInt64Array,
32};
33use arrow_schema::{DataType, Field, Schema, SchemaRef};
34use async_trait::async_trait;
35use datafusion::execution::SendableRecordBatchStream;
36use datafusion::physical_plan::metrics::Time;
37use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
38use deepsize::DeepSizeOf;
39use fst::{Automaton, IntoStreamer, Streamer};
40use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream};
41use itertools::Itertools;
42use lance_arrow::{RecordBatchExt, iter_str_array};
43use lance_core::cache::{CacheCodec, CacheKey, LanceCache, WeakLanceCache};
44use lance_core::error::{DataFusionResult, LanceOptionExt};
45use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
46use lance_core::utils::tokio::{get_num_compute_intensive_cpus, spawn_cpu};
47use lance_core::utils::tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS};
48use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result};
49use roaring::RoaringBitmap;
50use std::sync::LazyLock;
51use tokio::task::spawn_blocking;
52use tracing::{info, instrument};
53
54use super::encoding::PositionBlockBuilder;
55use super::iter::PostingListIterator;
56use super::{InvertedIndexBuilder, InvertedIndexParams, wand::*};
57use super::{
58 builder::{
59 BLOCK_SIZE, ScoredDoc, doc_file_path, inverted_list_schema_for_version, posting_file_path,
60 token_file_path,
61 },
62 iter::PlainPostingListIterator,
63 query::*,
64 scorer::{B, IndexBM25Scorer, K1, Scorer, idf},
65};
66use super::{
67 builder::{InnerBuilder, PositionRecorder},
68 iter::CompressedPostingListIterator,
69};
70use crate::frag_reuse::FragReuseIndex;
71use crate::pbold;
72use crate::progress::IndexBuildProgress;
73use crate::scalar::inverted::scorer::MemBM25Scorer;
74use crate::scalar::inverted::tokenizer::document_tokenizer::LanceTokenizer;
75use crate::scalar::{
76 AnyQuery, BuiltinIndexType, CreatedIndex, IndexReader, IndexStore, MetricsCollector,
77 OldIndexDataFilter, ScalarIndex, ScalarIndexParams, SearchResult, TokenQuery, UpdateCriteria,
78};
79use crate::{FtsPrewarmOptions, Index};
80use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
81use std::str::FromStr;
82
/// Index version number reported for [`InvertedListFormatVersion::V1`].
pub const INVERTED_INDEX_VERSION_V1: u32 = 1;
/// Index version number reported for [`InvertedListFormatVersion::V2`].
pub const INVERTED_INDEX_VERSION_V2: u32 = 2;
// Index file names. `TOKENS_FILE`, `INVERT_LIST_FILE` and `DOCS_FILE` are the
// files opened by the legacy (single-partition) loader; `METADATA_FILE` is the
// file probed first by `InvertedIndex::load`.
pub const TOKENS_FILE: &str = "tokens.lance";
pub const INVERT_LIST_FILE: &str = "invert.lance";
pub const DOCS_FILE: &str = "docs.lance";
pub const METADATA_FILE: &str = "metadata.lance";

// Column names used by the index files.
// NOTE(review): most of these are consumed outside this section; their exact
// meaning is implied by the names only — confirm against the readers/writers.
pub const TOKEN_COL: &str = "_token";
pub const TOKEN_ID_COL: &str = "_token_id";
pub const TOKEN_FST_BYTES_COL: &str = "_token_fst_bytes";
pub const TOKEN_NEXT_ID_COL: &str = "_token_next_id";
pub const TOKEN_TOTAL_LENGTH_COL: &str = "_token_total_length";
pub const FREQUENCY_COL: &str = "_frequency";
pub const POSITION_COL: &str = "_position";
pub const COMPRESSED_POSITION_COL: &str = "_compressed_position";
pub const POSITION_BLOCK_OFFSET_COL: &str = "_position_block_offset";
pub const POSTING_COL: &str = "_posting";
pub const MAX_SCORE_COL: &str = "_max_score";
pub const LENGTH_COL: &str = "_length";
pub const BLOCK_MAX_SCORE_COL: &str = "_block_max_score";
pub const NUM_TOKEN_COL: &str = "_num_tokens";
/// Name of the score column in [`FTS_SCHEMA`].
pub const SCORE_COL: &str = "_score";
// Schema-metadata keys. `TOKEN_SET_FORMAT_KEY` is read by `InvertedIndex::load`;
// the posting/positions keys are read by the `parse_*` helpers in this file.
pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format";
pub const POSTING_TAIL_CODEC_KEY: &str = "posting_tail_codec";
pub const POSITIONS_LAYOUT_KEY: &str = "positions_layout";
pub const POSITIONS_CODEC_KEY: &str = "positions_codec";
// Metadata values: the textual names of `PostingTailCodec` and
// `PositionStreamCodec` variants, plus the shared-stream positions layout name.
pub const POSTING_TAIL_CODEC_FIXED32_V1: &str = "fixed32_v1";
pub const POSTING_TAIL_CODEC_VARINT_DELTA_V1: &str = "varint_delta_v1";
pub const POSITIONS_LAYOUT_SHARED_STREAM_V2: &str = "shared_stream_v2";
pub const POSITIONS_CODEC_VARINT_DOC_DELTA_V2: &str = "varint_doc_delta_v2";
pub const POSITIONS_CODEC_PACKED_DELTA_V1: &str = "packed_delta_v1";
/// Column of the metadata file holding the serialized bitmap of deleted
/// fragments (deserialized as a `RoaringBitmap` by `InvertedIndex::load`).
pub const DELETED_FRAGMENTS_COL: &str = "deleted_fragments";

// NOTE(review): presumably a sizing heuristic for per-row token buffers; the
// consumer is not visible in this section — confirm before relying on it.
pub const ESTIMATED_MAX_TOKENS_PER_ROW: usize = 4 * 1024;
121
/// Nullable `Float32` field named [`SCORE_COL`].
pub static SCORE_FIELD: LazyLock<Field> =
    LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
/// Two-column schema: the row id field followed by [`SCORE_FIELD`].
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
/// Single-column schema containing only the row id field.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
128
129fn resolve_fts_format_version(
130 value: Option<&str>,
131) -> std::result::Result<InvertedListFormatVersion, Error> {
132 value.unwrap_or("1").parse()
133}
134
135pub fn current_fts_format_version() -> InvertedListFormatVersion {
136 resolve_fts_format_version(std::env::var("LANCE_FTS_FORMAT_VERSION").ok().as_deref())
137 .expect("failed to parse LANCE_FTS_FORMAT_VERSION")
138}
139
/// Returns the newest format version this code reports as supported
/// (currently always [`InvertedListFormatVersion::V2`]).
pub fn max_supported_fts_format_version() -> InvertedListFormatVersion {
    InvertedListFormatVersion::V2
}
143
/// On-disk format version of the inverted (posting) lists.
///
/// `V1` is the `Default`. Each version maps to an index version number, a
/// posting tail codec and (for `V2`) a position stream codec through the
/// inherent methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum InvertedListFormatVersion {
    /// Fixed32 posting tail codec, no shared position stream.
    #[default]
    V1,
    /// Varint-delta posting tail codec with a shared position stream.
    V2,
}
150
151impl InvertedListFormatVersion {
152 pub fn from_posting_tail_codec(codec: PostingTailCodec) -> Self {
153 match codec {
154 PostingTailCodec::Fixed32 => Self::V1,
155 PostingTailCodec::VarintDelta => Self::V2,
156 }
157 }
158
159 pub fn index_version(self) -> u32 {
160 match self {
161 Self::V1 => INVERTED_INDEX_VERSION_V1,
162 Self::V2 => INVERTED_INDEX_VERSION_V2,
163 }
164 }
165
166 pub fn posting_tail_codec(self) -> PostingTailCodec {
167 match self {
168 Self::V1 => PostingTailCodec::Fixed32,
169 Self::V2 => PostingTailCodec::VarintDelta,
170 }
171 }
172
173 pub fn position_codec(self) -> Option<PositionStreamCodec> {
174 match self {
175 Self::V1 => None,
176 Self::V2 => Some(PositionStreamCodec::PackedDelta),
177 }
178 }
179
180 pub fn uses_shared_position_stream(self) -> bool {
181 matches!(self, Self::V2)
182 }
183}
184
185impl FromStr for InvertedListFormatVersion {
186 type Err = Error;
187
188 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
189 match s.trim() {
190 "1" | "v1" | "V1" => Ok(Self::V1),
191 "2" | "v2" | "V2" => Ok(Self::V2),
192 other => Err(Error::index(format!(
193 "unsupported FTS format version {}, expected 1 or 2",
194 other
195 ))),
196 }
197 }
198}
199
200#[derive(Debug)]
201struct PartitionCandidates {
202 tokens_by_position: Vec<String>,
203 candidates: Vec<DocCandidate>,
204}
205
206impl PartitionCandidates {
207 fn empty() -> Self {
208 Self {
209 tokens_by_position: Vec::new(),
210 candidates: Vec::new(),
211 }
212 }
213}
214
/// Storage format of a partition's token set.
///
/// `Fst` is the `Default`. The textual names (`"arrow"` / `"fst"`) are
/// produced by the `Display` impl and accepted by the `FromStr` impl.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
pub enum TokenSetFormat {
    /// Arrow-based token set; also the format used for legacy indices.
    Arrow,
    /// FST-based token set.
    #[default]
    Fst,
}
221
222impl Display for TokenSetFormat {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 match self {
225 Self::Arrow => f.write_str("arrow"),
226 Self::Fst => f.write_str("fst"),
227 }
228 }
229}
230
231impl FromStr for TokenSetFormat {
232 type Err = Error;
233
234 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
235 match s.trim() {
236 "" => Ok(Self::Arrow),
237 "arrow" => Ok(Self::Arrow),
238 "fst" => Ok(Self::Fst),
239 other => Err(Error::index(format!(
240 "unsupported token set format {}",
241 other
242 ))),
243 }
244 }
245}
246
247impl DeepSizeOf for TokenSetFormat {
248 fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
249 0
250 }
251}
252
/// Codec of the shared position stream.
///
/// `PackedDelta` is the `Default`. The metadata names of the variants are
/// given by `PositionStreamCodec::as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PositionStreamCodec {
    /// Stored in metadata as `POSITIONS_CODEC_VARINT_DOC_DELTA_V2`.
    VarintDocDelta,
    /// Stored in metadata as `POSITIONS_CODEC_PACKED_DELTA_V1`.
    #[default]
    PackedDelta,
}
259
/// Codec of the posting list tail.
///
/// `VarintDelta` is the `Default` of this enum. Note that
/// `parse_posting_tail_codec` falls back to `Fixed32` (not this default) when
/// the metadata key is absent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PostingTailCodec {
    /// Stored in metadata as `POSTING_TAIL_CODEC_FIXED32_V1`.
    Fixed32,
    /// Stored in metadata as `POSTING_TAIL_CODEC_VARINT_DELTA_V1`.
    #[default]
    VarintDelta,
}
266
267impl PostingTailCodec {
268 pub fn as_str(self) -> &'static str {
269 match self {
270 Self::Fixed32 => POSTING_TAIL_CODEC_FIXED32_V1,
271 Self::VarintDelta => POSTING_TAIL_CODEC_VARINT_DELTA_V1,
272 }
273 }
274
275 fn from_metadata_value(value: &str) -> Result<Self> {
276 match value.trim() {
277 POSTING_TAIL_CODEC_FIXED32_V1 => Ok(Self::Fixed32),
278 POSTING_TAIL_CODEC_VARINT_DELTA_V1 => Ok(Self::VarintDelta),
279 other => Err(Error::index(format!(
280 "unsupported posting tail codec {}",
281 other
282 ))),
283 }
284 }
285}
286
287pub(super) fn parse_posting_tail_codec(
288 metadata: &HashMap<String, String>,
289) -> Result<PostingTailCodec> {
290 Ok(metadata
291 .get(POSTING_TAIL_CODEC_KEY)
292 .map(|codec| PostingTailCodec::from_metadata_value(codec))
293 .transpose()?
294 .unwrap_or(PostingTailCodec::Fixed32))
295}
296
297impl PositionStreamCodec {
298 pub fn as_str(self) -> &'static str {
299 match self {
300 Self::VarintDocDelta => POSITIONS_CODEC_VARINT_DOC_DELTA_V2,
301 Self::PackedDelta => POSITIONS_CODEC_PACKED_DELTA_V1,
302 }
303 }
304
305 fn from_metadata_value(value: &str) -> Result<Self> {
306 match value.trim() {
307 POSITIONS_CODEC_VARINT_DOC_DELTA_V2 => Ok(Self::VarintDocDelta),
308 POSITIONS_CODEC_PACKED_DELTA_V1 => Ok(Self::PackedDelta),
309 other => Err(Error::index(format!(
310 "unsupported positions codec {}",
311 other
312 ))),
313 }
314 }
315}
316
317fn parse_shared_position_codec(metadata: &HashMap<String, String>) -> Result<PositionStreamCodec> {
318 if let Some(codec) = metadata.get(POSITIONS_CODEC_KEY) {
319 return PositionStreamCodec::from_metadata_value(codec);
320 }
321
322 match metadata
323 .get(POSITIONS_LAYOUT_KEY)
324 .map(|layout| layout.as_str())
325 {
326 Some(POSITIONS_LAYOUT_SHARED_STREAM_V2) => Ok(PositionStreamCodec::VarintDocDelta),
327 _ => Ok(PositionStreamCodec::VarintDocDelta),
328 }
329}
330
331pub(super) fn parse_format_version_from_metadata(
332 metadata: &HashMap<String, String>,
333) -> Result<InvertedListFormatVersion> {
334 if metadata.contains_key(POSITIONS_CODEC_KEY) || metadata.contains_key(POSITIONS_LAYOUT_KEY) {
335 return Ok(InvertedListFormatVersion::V2);
336 }
337 if parse_posting_tail_codec(metadata)? == PostingTailCodec::VarintDelta {
338 Ok(InvertedListFormatVersion::V2)
339 } else {
340 Ok(InvertedListFormatVersion::V1)
341 }
342}
343
/// A loaded inverted (full-text search) index made of one or more partitions.
#[derive(Clone)]
pub struct InvertedIndex {
    /// Index parameters; the tokenizer is built from these.
    params: InvertedIndexParams,
    /// Store the index files were opened from.
    store: Arc<dyn IndexStore>,
    /// Tokenizer built via `params.build()` at load time.
    tokenizer: Box<dyn LanceTokenizer>,
    /// Token set format shared by the partitions (`Arrow` for legacy indices).
    token_set_format: TokenSetFormat,
    /// Loaded partitions; a legacy index has exactly one.
    pub(crate) partitions: Vec<Arc<InvertedPartition>>,
    /// Bitmap read from the metadata file's deleted-fragments column; empty
    /// when that column (or the metadata file) is absent.
    deleted_fragments: RoaringBitmap,
}
355
356impl Debug for InvertedIndex {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 f.debug_struct("InvertedIndex")
359 .field("params", &self.params)
360 .field("token_set_format", &self.token_set_format)
361 .field("partitions", &self.partitions)
362 .field("deleted_fragments", &self.deleted_fragments)
363 .finish()
364 }
365}
366
367impl DeepSizeOf for InvertedIndex {
368 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
369 self.partitions.deep_size_of_children(context)
370 }
371}
372
373impl InvertedIndex {
374 fn format_version(&self) -> InvertedListFormatVersion {
375 self.partitions
376 .first()
377 .map(|partition| {
378 InvertedListFormatVersion::from_posting_tail_codec(
379 partition.inverted_list.posting_tail_codec(),
380 )
381 })
382 .unwrap_or_else(current_fts_format_version)
383 }
384
385 fn index_version(&self) -> u32 {
386 match self.token_set_format {
387 TokenSetFormat::Arrow => 0,
388 TokenSetFormat::Fst => self.format_version().index_version(),
389 }
390 }
391
392 fn posting_tail_codec(&self) -> PostingTailCodec {
393 self.partitions
394 .first()
395 .map(|partition| partition.inverted_list.posting_tail_codec())
396 .unwrap_or_default()
397 }
398
    /// Creates a builder seeded from this index without a fragment mask.
    fn to_builder(&self) -> InvertedIndexBuilder {
        self.to_builder_with_offset(None)
    }
402
403 fn to_builder_with_offset(&self, fragment_mask: Option<u64>) -> InvertedIndexBuilder {
404 if self.is_legacy() {
405 InvertedIndexBuilder::from_existing_index(
407 self.params.clone(),
408 None,
409 Vec::new(),
410 self.token_set_format,
411 fragment_mask,
412 self.deleted_fragments.clone(),
413 )
414 .with_posting_tail_codec(self.posting_tail_codec())
415 } else {
416 let partitions = match fragment_mask {
417 Some(fragment_mask) => self
418 .partitions
419 .iter()
420 .filter(|part| part.belongs_to_fragment(fragment_mask))
424 .map(|part| part.id())
425 .collect(),
426 None => self.partitions.iter().map(|part| part.id()).collect(),
427 };
428
429 InvertedIndexBuilder::from_existing_index(
430 self.params.clone(),
431 Some(self.store.clone()),
432 partitions,
433 self.token_set_format,
434 fragment_mask,
435 self.deleted_fragments.clone(),
436 )
437 .with_format_version(self.format_version())
438 }
439 }
440
    /// Returns a clone of the index's boxed tokenizer.
    pub fn tokenizer(&self) -> Box<dyn LanceTokenizer> {
        self.tokenizer.clone()
    }
444
    /// Returns the parameters this index was loaded with.
    pub fn params(&self) -> &InvertedIndexParams {
        &self.params
    }
448
    /// Returns the number of loaded partitions.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
    /// Returns the bitmap of deleted fragments recorded with this index.
    pub fn deleted_fragments(&self) -> &RoaringBitmap {
        &self.deleted_fragments
    }
461
462 pub async fn merge_segments(
463 segments: &[Arc<Self>],
464 new_data: SendableRecordBatchStream,
465 dest_store: &dyn IndexStore,
466 old_data_filter: Option<OldIndexDataFilter>,
467 progress: Arc<dyn IndexBuildProgress>,
468 ) -> Result<CreatedIndex> {
469 let Some(first) = segments.first() else {
470 return Err(Error::invalid_input(
471 "cannot merge inverted index without at least one source segment".to_string(),
472 ));
473 };
474
475 for segment in segments.iter().skip(1) {
476 if segment.params != first.params {
477 return Err(Error::index(
478 "cannot merge inverted index segments with different parameters".to_string(),
479 ));
480 }
481 if segment.token_set_format != first.token_set_format {
482 return Err(Error::index(
483 "cannot merge inverted index segments with different token set formats"
484 .to_string(),
485 ));
486 }
487 if segment.format_version() != first.format_version() {
488 return Err(Error::index(
489 "cannot merge inverted index segments with different format versions"
490 .to_string(),
491 ));
492 }
493 if segment.posting_tail_codec() != first.posting_tail_codec() {
494 return Err(Error::index(
495 "cannot merge inverted index segments with different posting tail codecs"
496 .to_string(),
497 ));
498 }
499 }
500
501 let mut builder = InvertedIndexBuilder::new(first.params.clone()).with_progress(progress);
502 builder = builder
503 .with_token_set_format(first.token_set_format)
504 .with_format_version(first.format_version())
505 .with_posting_tail_codec(first.posting_tail_codec());
506 builder
507 .update_from_segments(new_data, dest_store, segments, old_data_filter)
508 .await?;
509
510 let details = pbold::InvertedIndexDetails::try_from(&first.params)?;
511
512 Ok(CreatedIndex {
513 index_details: prost_types::Any::from_msg(&details).unwrap(),
514 index_version: first.index_version(),
515 files: Some(dest_store.list_files_with_sizes().await?),
516 })
517 }
518
519 pub fn bm25_base_scorer(&self, query_tokens: &Tokens) -> MemBM25Scorer {
520 let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
521 let token_docs = query_tokens
522 .into_iter()
523 .map(|token| (token.to_string(), scorer.num_docs_containing_token(token)))
524 .collect::<HashMap<_, _>>();
525 MemBM25Scorer::new(scorer.total_tokens(), scorer.num_docs(), token_docs)
526 }
527
528 pub fn bm25_stats_for_terms(&self, terms: &[String]) -> (u64, usize, Vec<usize>) {
529 let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
530 let token_docs = terms
531 .iter()
532 .map(|term| scorer.num_docs_containing_token(term))
533 .collect();
534 (scorer.total_tokens(), scorer.num_docs(), token_docs)
535 }
536
537 pub fn expand_fuzzy_tokens(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
539 let mut expanded_tokens = Vec::new();
540 let mut expanded_positions = Vec::new();
541 let mut seen = HashSet::new();
542 for partition in &self.partitions {
543 let expanded = partition.expand_fuzzy(tokens, params)?;
544 for idx in 0..expanded.len() {
545 let token = expanded.get_token(idx);
546 if seen.insert(token.to_string()) {
547 expanded_tokens.push(token.to_string());
548 expanded_positions.push(expanded.position(idx));
549 }
550 }
551 }
552 Ok(Tokens::with_positions(
553 expanded_tokens,
554 expanded_positions,
555 tokens.token_type().clone(),
556 ))
557 }
558
    /// Runs a BM25 search over every partition and returns the best matches.
    ///
    /// * `tokens` / `params` / `operator` - the query, forwarded to each
    ///   partition's posting-list loading and search.
    /// * `prefilter` - supplies the row mask handed to each partition search.
    /// * `metrics` - collector forwarded to the partitions.
    /// * `base_scorer` - scorer to use for query and document weights; when
    ///   `None`, a scorer over this index's partitions is created locally.
    ///
    /// Returns parallel vectors of row ids and scores holding at most
    /// `params.limit` entries (no limit means `usize::MAX`). A limit of `0`
    /// returns empty vectors without touching any partition.
    ///
    /// # Errors
    /// Propagates the first error from loading posting lists or from a
    /// partition search.
    #[instrument(level = "debug", skip_all)]
    pub async fn bm25_search(
        &self,
        tokens: Arc<Tokens>,
        params: Arc<FtsSearchParams>,
        operator: Operator,
        prefilter: Arc<dyn PreFilter>,
        metrics: Arc<dyn MetricsCollector>,
        base_scorer: Option<&MemBM25Scorer>,
    ) -> Result<(Vec<u64>, Vec<f32>)> {
        // `local_scorer` is declared first so that a reference to it can
        // outlive the `else` branch that initializes it.
        let local_scorer;
        let scorer: &dyn Scorer = if let Some(base_scorer) = base_scorer {
            base_scorer
        } else {
            local_scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
            &local_scorer
        };

        let limit = params.limit.unwrap_or(usize::MAX);
        if limit == 0 {
            return Ok((Vec::new(), Vec::new()));
        }
        let mask = prefilter.mask();

        // Heap of `Reverse(ScoredDoc)`: `peek()` yields the weakest retained
        // candidate (presumably ordered by score — confirm `ScoredDoc`'s `Ord`).
        let mut candidates = BinaryHeap::new();
        // One future per partition: load posting lists asynchronously, then
        // run the partition search on the CPU pool via `spawn_cpu`.
        let parts = self
            .partitions
            .iter()
            .map(|part| {
                let part = part.clone();
                let tokens = tokens.clone();
                let params = params.clone();
                let mask = mask.clone();
                let metrics = metrics.clone();
                async move {
                    let postings = part
                        .load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
                        .await?;
                    if postings.is_empty() {
                        return Result::Ok(PartitionCandidates::empty());
                    }
                    // Record each posting's token under its term index so the
                    // scoring loop below can map term indices back to tokens.
                    // Indices without a posting keep an empty string.
                    let max_position = postings
                        .iter()
                        .map(|posting| posting.term_index() as usize)
                        .max()
                        .unwrap_or_default();
                    let mut tokens_by_position = vec![String::new(); max_position + 1];
                    for posting in &postings {
                        let idx = posting.term_index() as usize;
                        tokens_by_position[idx] = posting.token().to_owned();
                    }
                    let params = params.clone();
                    let mask = mask.clone();
                    let metrics = metrics.clone();
                    spawn_cpu(move || {
                        let candidates = part.bm25_search(
                            params.as_ref(),
                            operator,
                            mask,
                            postings,
                            metrics.as_ref(),
                        )?;
                        Ok(PartitionCandidates {
                            tokens_by_position,
                            candidates,
                        })
                    })
                    .await
                }
            })
            .collect::<Vec<_>>();
        // Partition results arrive in completion order, not partition order.
        let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
        // Query weights are computed once per distinct token text and reused
        // across partitions.
        let mut idf_cache: HashMap<String, f32> = HashMap::new();
        while let Some(res) = parts.try_next().await? {
            if res.candidates.is_empty() {
                continue;
            }
            // Query weight per term index for this partition. Empty-string
            // placeholders also get a weight looked up for them.
            let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
            for token in &res.tokens_by_position {
                let idf_weight = match idf_cache.get(token) {
                    Some(weight) => *weight,
                    None => {
                        let weight = scorer.query_weight(token);
                        idf_cache.insert(token.clone(), weight);
                        weight
                    }
                };
                idf_by_position.push(idf_weight);
            }
            for DocCandidate {
                row_id,
                freqs,
                doc_length,
            } in res.candidates
            {
                // Score = sum over matched terms of query weight * doc weight.
                let mut score = 0.0;
                for (term_index, freq) in freqs.into_iter() {
                    debug_assert!((term_index as usize) < idf_by_position.len());
                    score +=
                        idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length);
                }
                // Keep at most `limit` candidates. Once full, replace the
                // weakest one only when the new score is strictly higher.
                // `peek()` cannot be `None` here: `limit >= 1` and the heap
                // already holds `limit` entries.
                if candidates.len() < limit {
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                } else if candidates.peek().unwrap().0.score.0 < score {
                    candidates.pop();
                    candidates.push(Reverse(ScoredDoc::new(row_id, score)));
                }
            }
        }

        // Ascending order of `Reverse(..)` is descending order of the wrapped
        // documents.
        Ok(candidates
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(doc)| (doc.row_id, doc.score.0))
            .unzip())
    }
679
680 async fn load_legacy_index(
681 store: Arc<dyn IndexStore>,
682 frag_reuse_index: Option<Arc<FragReuseIndex>>,
683 index_cache: &LanceCache,
684 ) -> Result<Arc<Self>> {
685 log::warn!("loading legacy FTS index");
686 let tokens_fut = tokio::spawn({
687 let store = store.clone();
688 async move {
689 let token_reader = store.open_index_file(TOKENS_FILE).await?;
690 let tokenizer = token_reader
691 .schema()
692 .metadata
693 .get("tokenizer")
694 .map(|s| serde_json::from_str::<InvertedIndexParams>(s))
695 .transpose()?
696 .unwrap_or_default();
697 let tokens = TokenSet::load(token_reader, TokenSetFormat::Arrow).await?;
698 Result::Ok((tokenizer, tokens))
699 }
700 });
701 let invert_list_fut = tokio::spawn({
702 let store = store.clone();
703 let index_cache_clone = index_cache.clone();
704 async move {
705 let invert_list_reader = store.open_index_file(INVERT_LIST_FILE).await?;
706 let invert_list =
707 PostingListReader::try_new(invert_list_reader, &index_cache_clone).await?;
708 Result::Ok(Arc::new(invert_list))
709 }
710 });
711 let docs_fut = tokio::spawn({
712 let store = store.clone();
713 async move {
714 let docs_reader = store.open_index_file(DOCS_FILE).await?;
715 let docs = DocSet::load(docs_reader, true, frag_reuse_index).await?;
716 Result::Ok(docs)
717 }
718 });
719
720 let (tokenizer_config, tokens) = tokens_fut.await??;
721 let inverted_list = invert_list_fut.await??;
722 let docs = docs_fut.await??;
723
724 let tokenizer = tokenizer_config.build()?;
725
726 Ok(Arc::new(Self {
727 params: tokenizer_config,
728 store: store.clone(),
729 tokenizer,
730 token_set_format: TokenSetFormat::Arrow,
731 partitions: vec![Arc::new(InvertedPartition {
732 id: 0,
733 store,
734 tokens,
735 inverted_list,
736 docs,
737 token_set_format: TokenSetFormat::Arrow,
738 })],
739 deleted_fragments: RoaringBitmap::new(),
740 }))
741 }
742
743 pub fn is_legacy(&self) -> bool {
744 self.partitions.len() == 1 && self.partitions[0].is_legacy()
745 }
746
747 pub async fn load(
748 store: Arc<dyn IndexStore>,
749 frag_reuse_index: Option<Arc<FragReuseIndex>>,
750 index_cache: &LanceCache,
751 ) -> Result<Arc<Self>>
752 where
753 Self: Sized,
754 {
755 match store.open_index_file(METADATA_FILE).await {
760 Ok(reader) => {
761 let params = reader
762 .schema()
763 .metadata
764 .get("params")
765 .ok_or(Error::index("params not found in metadata".to_owned()))?;
766 let params = serde_json::from_str::<InvertedIndexParams>(params)?;
767 let partitions = reader
768 .schema()
769 .metadata
770 .get("partitions")
771 .ok_or(Error::index("partitions not found in metadata".to_owned()))?;
772 let partitions: Vec<u64> = serde_json::from_str(partitions)?;
773 let token_set_format = reader
774 .schema()
775 .metadata
776 .get(TOKEN_SET_FORMAT_KEY)
777 .map(|name| TokenSetFormat::from_str(name))
778 .transpose()?
779 .unwrap_or(TokenSetFormat::Arrow);
780
781 let deleted_fragments = if reader.num_rows() > 0 {
783 let metadata_batch = reader.read_range(0..1, None).await?;
784 if let Some(col) = metadata_batch.column_by_name(DELETED_FRAGMENTS_COL) {
785 let arr = col.as_binary_opt::<i32>().expect_ok()?;
786 RoaringBitmap::deserialize_from(arr.value(0))?
787 } else {
788 RoaringBitmap::new()
789 }
790 } else {
791 RoaringBitmap::new()
792 };
793
794 let format = token_set_format;
795 let partitions = partitions.into_iter().map(|id| {
796 let store = store.clone();
797 let frag_reuse_index_clone = frag_reuse_index.clone();
798 let index_cache_for_part =
799 index_cache.with_key_prefix(format!("part-{}", id).as_str());
800 let token_set_format = format;
801 async move {
802 Result::Ok(Arc::new(
803 InvertedPartition::load(
804 store,
805 id,
806 frag_reuse_index_clone,
807 &index_cache_for_part,
808 token_set_format,
809 )
810 .await?,
811 ))
812 }
813 });
814 let partitions = stream::iter(partitions)
815 .buffer_unordered(store.io_parallelism())
816 .try_collect::<Vec<_>>()
817 .await?;
818
819 let tokenizer = params.build()?;
820 Ok(Arc::new(Self {
821 params,
822 store,
823 tokenizer,
824 token_set_format,
825 partitions,
826 deleted_fragments,
827 }))
828 }
829 Err(_) => {
830 Self::load_legacy_index(store, frag_reuse_index, index_cache).await
832 }
833 }
834 }
835}
836
837#[async_trait]
838impl Index for InvertedIndex {
839 fn as_any(&self) -> &dyn std::any::Any {
840 self
841 }
842
843 fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
844 self
845 }
846
847 fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
848 Err(Error::invalid_input(
849 "inverted index cannot be cast to vector index",
850 ))
851 }
852
853 fn statistics(&self) -> Result<serde_json::Value> {
854 let num_tokens = self
855 .partitions
856 .iter()
857 .map(|part| part.tokens.len())
858 .sum::<usize>();
859 let num_docs = self
860 .partitions
861 .iter()
862 .map(|part| part.docs.len())
863 .sum::<usize>();
864 Ok(serde_json::json!({
865 "params": self.params,
866 "num_tokens": num_tokens,
867 "num_docs": num_docs,
868 }))
869 }
870
871 async fn prewarm(&self) -> Result<()> {
872 self.prewarm_with_options(&FtsPrewarmOptions::default())
873 .await
874 }
875
876 fn index_type(&self) -> crate::IndexType {
877 crate::IndexType::Inverted
878 }
879
880 async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
881 unimplemented!()
882 }
883}
884
885impl InvertedIndex {
886 pub async fn prewarm_with_options(&self, options: &FtsPrewarmOptions) -> Result<()> {
887 let with_position = options.with_position;
888 let io_parallelism = self.store.io_parallelism();
889 let prewarm_futures = self
890 .partitions
891 .iter()
892 .map(Arc::clone)
893 .map(|part| async move {
894 part.inverted_list
895 .prewarm_posting_lists(with_position)
896 .await?;
897 Result::Ok(())
898 });
899 stream::iter(prewarm_futures)
900 .buffer_unordered(io_parallelism)
901 .try_collect::<Vec<_>>()
902 .await?;
903 Ok(())
904 }
905 async fn do_search(&self, text: &str) -> Result<RecordBatch> {
907 let params = FtsSearchParams::new();
908 let mut tokenizer = self.tokenizer.clone();
909 let tokens = collect_query_tokens(text, &mut tokenizer);
910
911 let (doc_ids, _) = self
912 .bm25_search(
913 Arc::new(tokens),
914 params.into(),
915 Operator::And,
916 Arc::new(NoFilter),
917 Arc::new(NoOpMetricsCollector),
918 None,
919 )
920 .boxed()
921 .await?;
922
923 Ok(RecordBatch::try_new(
924 ROW_ID_SCHEMA.clone(),
925 vec![Arc::new(UInt64Array::from(doc_ids))],
926 )?)
927 }
928}
929
930#[async_trait]
931impl ScalarIndex for InvertedIndex {
932 #[instrument(level = "debug", skip_all)]
934 async fn search(
935 &self,
936 query: &dyn AnyQuery,
937 _metrics: &dyn MetricsCollector,
938 ) -> Result<SearchResult> {
939 let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();
940
941 match query {
942 TokenQuery::TokensContains(text) => {
943 let records = self.do_search(text).await?;
944 let row_ids = records
945 .column(0)
946 .as_any()
947 .downcast_ref::<UInt64Array>()
948 .unwrap();
949 let row_ids = row_ids.iter().flatten().collect_vec();
950 Ok(SearchResult::at_most(RowAddrTreeMap::from_iter(row_ids)))
951 }
952 }
953 }
954
955 fn can_remap(&self) -> bool {
956 true
957 }
958
959 async fn remap(
960 &self,
961 mapping: &HashMap<u64, Option<u64>>,
962 dest_store: &dyn IndexStore,
963 ) -> Result<CreatedIndex> {
964 self.to_builder()
965 .remap(mapping, self.store.clone(), dest_store)
966 .await?;
967
968 let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
969
970 Ok(CreatedIndex {
971 index_details: prost_types::Any::from_msg(&details).unwrap(),
972 index_version: self.index_version(),
973 files: Some(dest_store.list_files_with_sizes().await?),
974 })
975 }
976
977 async fn update(
978 &self,
979 new_data: SendableRecordBatchStream,
980 dest_store: &dyn IndexStore,
981 old_data_filter: Option<crate::scalar::OldIndexDataFilter>,
982 ) -> Result<CreatedIndex> {
983 self.to_builder()
984 .update(new_data, dest_store, old_data_filter)
985 .await?;
986
987 let details = pbold::InvertedIndexDetails::try_from(&self.params)?;
988
989 Ok(CreatedIndex {
990 index_details: prost_types::Any::from_msg(&details).unwrap(),
991 index_version: self.index_version(),
992 files: Some(dest_store.list_files_with_sizes().await?),
993 })
994 }
995
996 fn update_criteria(&self) -> UpdateCriteria {
997 let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
998 if self.is_legacy() {
999 UpdateCriteria::requires_old_data(criteria)
1000 } else {
1001 UpdateCriteria::only_new_data(criteria)
1002 }
1003 }
1004
1005 fn derive_index_params(&self) -> Result<ScalarIndexParams> {
1006 let mut params = self.params.clone();
1007 if params.base_tokenizer.is_empty() {
1008 params.base_tokenizer = "simple".to_string();
1009 }
1010
1011 let params_json = serde_json::to_string(¶ms)?;
1012
1013 Ok(ScalarIndexParams {
1014 index_type: BuiltinIndexType::Inverted.as_str().to_string(),
1015 params: Some(params_json),
1016 })
1017 }
1018}
1019
/// One partition of an inverted index: its token set, posting lists and
/// document set.
#[derive(Debug, Clone, DeepSizeOf)]
pub struct InvertedPartition {
    /// Partition id; used to derive the partition's file paths on load.
    id: u64,
    /// Store the partition's files were opened from.
    store: Arc<dyn IndexStore>,
    /// Token set mapping token text to token id.
    pub(crate) tokens: TokenSet,
    /// Reader for the partition's posting lists.
    pub(crate) inverted_list: Arc<PostingListReader>,
    /// Document set of the partition.
    pub(crate) docs: DocSet,
    /// Format the token set was loaded with.
    token_set_format: TokenSetFormat,
}
1030
1031impl InvertedPartition {
1032 pub fn belongs_to_fragment(&self, fragment_mask: u64) -> bool {
1043 (self.id() & fragment_mask) == fragment_mask
1044 }
1045
    /// Returns the partition id.
    pub fn id(&self) -> u64 {
        self.id
    }
1049
    /// Returns the store this partition was loaded from.
    pub fn store(&self) -> &dyn IndexStore {
        self.store.as_ref()
    }
1053
    /// Whether this partition is in the legacy layout, detected by the
    /// posting list reader having no `lengths`.
    pub fn is_legacy(&self) -> bool {
        self.inverted_list.lengths.is_none()
    }
1057
1058 pub async fn load(
1059 store: Arc<dyn IndexStore>,
1060 id: u64,
1061 frag_reuse_index: Option<Arc<FragReuseIndex>>,
1062 index_cache: &LanceCache,
1063 token_set_format: TokenSetFormat,
1064 ) -> Result<Self> {
1065 let token_file = store.open_index_file(&token_file_path(id)).await?;
1066 let tokens = TokenSet::load(token_file, token_set_format).await?;
1067 let invert_list_file = store.open_index_file(&posting_file_path(id)).await?;
1068 let inverted_list = PostingListReader::try_new(invert_list_file, index_cache).await?;
1069 let docs_file = store.open_index_file(&doc_file_path(id)).await?;
1070 let docs = DocSet::load(docs_file, false, frag_reuse_index).await?;
1071
1072 Ok(Self {
1073 id,
1074 store,
1075 tokens,
1076 inverted_list: Arc::new(inverted_list),
1077 docs,
1078 token_set_format,
1079 })
1080 }
1081
    /// Looks up the id of `token` in this partition's token set.
    fn map(&self, token: &str) -> Option<u32> {
        self.tokens.get(token)
    }
1085
1086 pub fn expand_fuzzy(&self, tokens: &Tokens, params: &FtsSearchParams) -> Result<Tokens> {
1087 let mut new_tokens = Vec::with_capacity(min(tokens.len(), params.max_expansions));
1088 for token in tokens {
1089 let fuzziness = match params.fuzziness {
1090 Some(fuzziness) => fuzziness,
1091 None => MatchQuery::auto_fuzziness(token),
1092 };
1093 let lev = fst::automaton::Levenshtein::new(token, fuzziness)
1094 .map_err(|e| Error::index(format!("failed to construct the fuzzy query: {}", e)))?;
1095
1096 let base_len = tokens.token_type().prefix_len(token) as u32;
1097 if let TokenMap::Fst(ref map) = self.tokens.tokens {
1098 match base_len + params.prefix_length {
1099 0 => take_fst_keys(map.search(lev), &mut new_tokens, params.max_expansions),
1100 prefix_length => {
1101 let prefix = &token[..min(prefix_length as usize, token.len())];
1102 let prefix = fst::automaton::Str::new(prefix).starts_with();
1103 take_fst_keys(
1104 map.search(lev.intersection(prefix)),
1105 &mut new_tokens,
1106 params.max_expansions,
1107 )
1108 }
1109 }
1110 } else {
1111 return Err(Error::index(
1112 "tokens is not fst, which is not expected".to_owned(),
1113 ));
1114 }
1115 }
1116 Ok(Tokens::new(new_tokens, tokens.token_type().clone()))
1117 }
1118
1119 #[instrument(level = "debug", skip_all)]
1123 pub async fn load_posting_lists(
1124 &self,
1125 tokens: &Tokens,
1126 params: &FtsSearchParams,
1127 metrics: &dyn MetricsCollector,
1128 ) -> Result<Vec<PostingIterator>> {
1129 let is_fuzzy = matches!(params.fuzziness, Some(n) if n != 0);
1130 let is_phrase_query = params.phrase_slop.is_some();
1131 let tokens = match is_fuzzy {
1132 true => self.expand_fuzzy(tokens, params)?,
1133 false => tokens.clone(),
1134 };
1135 let token_positions = (0..tokens.len())
1136 .map(|index| tokens.position(index))
1137 .collect::<Vec<_>>();
1138 let mut token_ids = Vec::with_capacity(tokens.len());
1139 for (index, token) in tokens.into_iter().enumerate() {
1140 let token_id = self.map(&token);
1141 if let Some(token_id) = token_id {
1142 token_ids.push((token_id, token, token_positions[index]));
1143 } else if is_phrase_query {
1144 return Ok(Vec::new());
1146 }
1147 }
1148 if token_ids.is_empty() {
1149 return Ok(Vec::new());
1150 }
1151 if !is_phrase_query {
1152 token_ids.sort_unstable_by_key(|(token_id, _, _)| *token_id);
1153 token_ids.dedup_by_key(|(token_id, _, _)| *token_id);
1154 }
1155
1156 let num_docs = self.docs.len();
1157 stream::iter(token_ids)
1158 .map(|(token_id, token, position)| async move {
1159 let posting = self
1160 .inverted_list
1161 .posting_list(token_id, is_phrase_query, metrics)
1162 .await?;
1163
1164 let query_weight = idf(posting.len(), num_docs);
1165
1166 Result::Ok(PostingIterator::with_query_weight(
1167 token,
1168 token_id,
1169 position,
1170 query_weight,
1171 posting,
1172 num_docs,
1173 ))
1174 })
1175 .buffered(self.store.io_parallelism())
1176 .try_collect::<Vec<_>>()
1177 .await
1178 }
1179
1180 #[instrument(level = "debug", skip_all)]
1181 pub fn bm25_search(
1182 &self,
1183 params: &FtsSearchParams,
1184 operator: Operator,
1185 mask: Arc<RowAddrMask>,
1186 postings: Vec<PostingIterator>,
1187 metrics: &dyn MetricsCollector,
1188 ) -> Result<Vec<DocCandidate>> {
1189 if postings.is_empty() {
1190 return Ok(Vec::new());
1191 }
1192
1193 let scorer = IndexBM25Scorer::new(std::iter::once(self));
1195 let mut wand = Wand::new(operator, postings.into_iter(), &self.docs, scorer);
1196 let hits = wand.search(params, mask, metrics)?;
1197 Ok(hits)
1199 }
1200
1201 pub async fn into_builder(self) -> Result<InnerBuilder> {
1202 let mut builder = InnerBuilder::new_with_posting_tail_codec(
1203 self.id,
1204 self.inverted_list.has_positions(),
1205 self.token_set_format,
1206 self.inverted_list.posting_tail_codec(),
1207 );
1208 builder.tokens = self.tokens.into_mutable();
1209 builder.docs = self.docs;
1210
1211 builder
1212 .posting_lists
1213 .reserve_exact(self.inverted_list.len());
1214 for posting_list in self
1215 .inverted_list
1216 .read_all(self.inverted_list.has_positions())
1217 .await?
1218 {
1219 let posting_list = posting_list?;
1220 builder
1221 .posting_lists
1222 .push(posting_list.into_builder(&builder.docs));
1223 }
1224 Ok(builder)
1225 }
1226}
1227
1228#[derive(Debug, Clone)]
1231pub enum TokenMap {
1232 HashMap(HashMap<String, u32>),
1233 Fst(fst::Map<Vec<u8>>),
1234}
1235
1236impl Default for TokenMap {
1237 fn default() -> Self {
1238 Self::HashMap(HashMap::new())
1239 }
1240}
1241
1242impl DeepSizeOf for TokenMap {
1243 fn deep_size_of_children(&self, ctx: &mut deepsize::Context) -> usize {
1244 match self {
1245 Self::HashMap(map) => map.deep_size_of_children(ctx),
1246 Self::Fst(map) => map.as_fst().size(),
1247 }
1248 }
1249}
1250
1251impl TokenMap {
1252 pub fn len(&self) -> usize {
1253 match self {
1254 Self::HashMap(map) => map.len(),
1255 Self::Fst(map) => map.len(),
1256 }
1257 }
1258
1259 pub fn is_empty(&self) -> bool {
1260 self.len() == 0
1261 }
1262}
1263
1264#[derive(Debug, Clone, Default, DeepSizeOf)]
1266pub struct TokenSet {
1267 pub(crate) tokens: TokenMap,
1269 pub(crate) next_id: u32,
1270 total_length: usize,
1271}
1272
1273impl TokenSet {
1274 pub fn into_mut(self) -> Self {
1275 let tokens = match self.tokens {
1276 TokenMap::HashMap(map) => map,
1277 TokenMap::Fst(map) => {
1278 let mut new_map = HashMap::with_capacity(map.len());
1279 let mut stream = map.into_stream();
1280 while let Some((token, token_id)) = stream.next() {
1281 new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1282 }
1283
1284 new_map
1285 }
1286 };
1287
1288 Self {
1289 tokens: TokenMap::HashMap(tokens),
1290 next_id: self.next_id,
1291 total_length: self.total_length,
1292 }
1293 }
1294
1295 pub fn len(&self) -> usize {
1296 self.tokens.len()
1297 }
1298
1299 pub fn is_empty(&self) -> bool {
1300 self.len() == 0
1301 }
1302
1303 pub fn to_batch(self, format: TokenSetFormat) -> Result<RecordBatch> {
1304 match format {
1305 TokenSetFormat::Arrow => self.into_arrow_batch(),
1306 TokenSetFormat::Fst => self.into_fst_batch(),
1307 }
1308 }
1309
1310 fn into_arrow_batch(self) -> Result<RecordBatch> {
1311 let mut token_builder = StringBuilder::with_capacity(self.tokens.len(), self.total_length);
1312 let mut token_id_builder = UInt32Builder::with_capacity(self.tokens.len());
1313
1314 match self.tokens {
1315 TokenMap::Fst(map) => {
1316 let mut stream = map.stream();
1317 while let Some((token, token_id)) = stream.next() {
1318 token_builder.append_value(String::from_utf8_lossy(token));
1319 token_id_builder.append_value(token_id as u32);
1320 }
1321 }
1322 TokenMap::HashMap(map) => {
1323 for (token, token_id) in map.into_iter().sorted_unstable() {
1324 token_builder.append_value(token);
1325 token_id_builder.append_value(token_id);
1326 }
1327 }
1328 }
1329
1330 let token_col = token_builder.finish();
1331 let token_id_col = token_id_builder.finish();
1332
1333 let schema = arrow_schema::Schema::new(vec![
1334 arrow_schema::Field::new(TOKEN_COL, DataType::Utf8, false),
1335 arrow_schema::Field::new(TOKEN_ID_COL, DataType::UInt32, false),
1336 ]);
1337
1338 let batch = RecordBatch::try_new(
1339 Arc::new(schema),
1340 vec![
1341 Arc::new(token_col) as ArrayRef,
1342 Arc::new(token_id_col) as ArrayRef,
1343 ],
1344 )?;
1345 Ok(batch)
1346 }
1347
1348 fn into_fst_batch(mut self) -> Result<RecordBatch> {
1349 let fst_map = match std::mem::take(&mut self.tokens) {
1350 TokenMap::Fst(map) => map,
1351 TokenMap::HashMap(map) => Self::build_fst_from_map(map)?,
1352 };
1353 let bytes = fst_map.into_fst().into_inner();
1354
1355 let mut fst_builder = LargeBinaryBuilder::with_capacity(1, bytes.len());
1356 fst_builder.append_value(bytes);
1357 let fst_col = fst_builder.finish();
1358
1359 let mut next_id_builder = UInt32Builder::with_capacity(1);
1360 next_id_builder.append_value(self.next_id);
1361 let next_id_col = next_id_builder.finish();
1362
1363 let mut total_length_builder = UInt64Builder::with_capacity(1);
1364 total_length_builder.append_value(self.total_length as u64);
1365 let total_length_col = total_length_builder.finish();
1366
1367 let schema = arrow_schema::Schema::new(vec![
1368 arrow_schema::Field::new(TOKEN_FST_BYTES_COL, DataType::LargeBinary, false),
1369 arrow_schema::Field::new(TOKEN_NEXT_ID_COL, DataType::UInt32, false),
1370 arrow_schema::Field::new(TOKEN_TOTAL_LENGTH_COL, DataType::UInt64, false),
1371 ]);
1372
1373 let batch = RecordBatch::try_new(
1374 Arc::new(schema),
1375 vec![
1376 Arc::new(fst_col) as ArrayRef,
1377 Arc::new(next_id_col) as ArrayRef,
1378 Arc::new(total_length_col) as ArrayRef,
1379 ],
1380 )?;
1381 Ok(batch)
1382 }
1383
1384 fn build_fst_from_map(map: HashMap<String, u32>) -> Result<fst::Map<Vec<u8>>> {
1385 let mut entries: Vec<_> = map.into_iter().collect();
1386 entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
1387 let mut builder = fst::MapBuilder::memory();
1388 for (token, token_id) in entries {
1389 builder
1390 .insert(&token, token_id as u64)
1391 .map_err(|e| Error::index(format!("failed to insert token {}: {}", token, e)))?;
1392 }
1393 Ok(builder.into_map())
1394 }
1395
1396 pub async fn load(reader: Arc<dyn IndexReader>, format: TokenSetFormat) -> Result<Self> {
1397 match format {
1398 TokenSetFormat::Arrow => Self::load_arrow(reader).await,
1399 TokenSetFormat::Fst => Self::load_fst(reader).await,
1400 }
1401 }
1402
1403 async fn load_arrow(reader: Arc<dyn IndexReader>) -> Result<Self> {
1404 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1405
1406 let (tokens, next_id, total_length) = spawn_blocking(move || {
1407 let mut next_id = 0;
1408 let mut total_length = 0;
1409 let mut tokens = fst::MapBuilder::memory();
1410
1411 let token_col = batch[TOKEN_COL].as_string::<i32>();
1412 let token_id_col = batch[TOKEN_ID_COL].as_primitive::<datatypes::UInt32Type>();
1413
1414 for (token, &token_id) in token_col.iter().zip(token_id_col.values().iter()) {
1415 let token =
1416 token.ok_or(Error::index("found null token in token set".to_owned()))?;
1417 next_id = next_id.max(token_id + 1);
1418 total_length += token.len();
1419 tokens.insert(token, token_id as u64).map_err(|e| {
1420 Error::index(format!("failed to insert token {}: {}", token, e))
1421 })?;
1422 }
1423
1424 Ok::<_, Error>((tokens.into_map(), next_id, total_length))
1425 })
1426 .await
1427 .map_err(|err| Error::execution(format!("failed to spawn blocking task: {}", err)))??;
1428
1429 Ok(Self {
1430 tokens: TokenMap::Fst(tokens),
1431 next_id,
1432 total_length,
1433 })
1434 }
1435
1436 async fn load_fst(reader: Arc<dyn IndexReader>) -> Result<Self> {
1437 let batch = reader.read_range(0..reader.num_rows(), None).await?;
1438 if batch.num_rows() == 0 {
1439 return Err(Error::index("token set batch is empty".to_owned()));
1440 }
1441
1442 let fst_col = batch[TOKEN_FST_BYTES_COL].as_binary::<i64>();
1443 let bytes = fst_col.value(0);
1444 let map = fst::Map::new(bytes.to_vec())
1445 .map_err(|e| Error::index(format!("failed to load fst tokens: {}", e)))?;
1446
1447 let next_id_col = batch[TOKEN_NEXT_ID_COL].as_primitive::<datatypes::UInt32Type>();
1448 let total_length_col =
1449 batch[TOKEN_TOTAL_LENGTH_COL].as_primitive::<datatypes::UInt64Type>();
1450
1451 let next_id = next_id_col
1452 .values()
1453 .first()
1454 .copied()
1455 .ok_or(Error::index("token next id column is empty".to_owned()))?;
1456
1457 let total_length = total_length_col
1458 .values()
1459 .first()
1460 .copied()
1461 .ok_or(Error::index(
1462 "token total length column is empty".to_owned(),
1463 ))?;
1464
1465 Ok(Self {
1466 tokens: TokenMap::Fst(map),
1467 next_id,
1468 total_length: usize::try_from(total_length).map_err(|_| {
1469 Error::index(format!(
1470 "token total length {} overflows usize",
1471 total_length
1472 ))
1473 })?,
1474 })
1475 }
1476
1477 pub fn add(&mut self, token: String) -> u32 {
1478 let next_id = self.next_id();
1479 let len = token.len();
1480 let token_id = match self.tokens {
1481 TokenMap::HashMap(ref mut map) => *map.entry(token).or_insert(next_id),
1482 _ => unreachable!("tokens must be HashMap while indexing"),
1483 };
1484
1485 if token_id == next_id {
1487 self.next_id += 1;
1488 self.total_length += len;
1489 }
1490
1491 token_id
1492 }
1493
1494 pub(crate) fn get_or_add(&mut self, token: &str) -> u32 {
1495 let next_id = self.next_id;
1496 match self.tokens {
1497 TokenMap::HashMap(ref mut map) => {
1498 if let Some(&token_id) = map.get(token) {
1499 return token_id;
1500 }
1501
1502 map.insert(token.to_owned(), next_id);
1503 }
1504 _ => unreachable!("tokens must be HashMap while indexing"),
1505 }
1506
1507 self.next_id += 1;
1508 self.total_length += token.len();
1509 next_id
1510 }
1511
1512 pub(crate) fn into_mutable(self) -> Self {
1513 let Self {
1514 tokens,
1515 next_id,
1516 total_length,
1517 } = self;
1518 match tokens {
1519 TokenMap::HashMap(_) => Self {
1520 tokens,
1521 next_id,
1522 total_length,
1523 },
1524 TokenMap::Fst(map) => {
1525 let mut mutable = HashMap::new();
1526 let mut stream = map.stream();
1527 while let Some((token, token_id)) = stream.next() {
1528 mutable.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1529 }
1530 Self {
1531 tokens: TokenMap::HashMap(mutable),
1532 next_id,
1533 total_length,
1534 }
1535 }
1536 }
1537 }
1538
1539 pub fn get(&self, token: &str) -> Option<u32> {
1540 match self.tokens {
1541 TokenMap::HashMap(ref map) => map.get(token).copied(),
1542 TokenMap::Fst(ref map) => map.get(token).map(|id| id as u32),
1543 }
1544 }
1545
1546 pub fn remap(&mut self, removed_token_ids: &[u32]) {
1548 if removed_token_ids.is_empty() {
1549 return;
1550 }
1551
1552 let mut map = match std::mem::take(&mut self.tokens) {
1553 TokenMap::HashMap(map) => map,
1554 TokenMap::Fst(map) => {
1555 let mut new_map = HashMap::with_capacity(map.len());
1556 let mut stream = map.into_stream();
1557 while let Some((token, token_id)) = stream.next() {
1558 new_map.insert(String::from_utf8_lossy(token).into_owned(), token_id as u32);
1559 }
1560
1561 new_map
1562 }
1563 };
1564
1565 map.retain(
1566 |_, token_id| match removed_token_ids.binary_search(token_id) {
1567 Ok(_) => false,
1568 Err(index) => {
1569 *token_id -= index as u32;
1570 true
1571 }
1572 },
1573 );
1574
1575 self.tokens = TokenMap::HashMap(map);
1576 }
1577
1578 pub fn next_id(&self) -> u32 {
1579 self.next_id
1580 }
1581
1582 pub(crate) fn memory_size(&self) -> usize {
1583 match &self.tokens {
1584 TokenMap::HashMap(map) => {
1585 self.total_length
1586 + map.capacity()
1587 * (std::mem::size_of::<String>()
1588 + std::mem::size_of::<u32>()
1589 + std::mem::size_of::<usize>())
1590 }
1591 TokenMap::Fst(map) => map.as_fst().size(),
1592 }
1593 }
1594}
1595
1596pub struct PostingListReader {
1597 reader: Arc<dyn IndexReader>,
1598
1599 offsets: Option<Vec<usize>>,
1601
1602 max_scores: Option<Vec<f32>>,
1605
1606 lengths: Option<Vec<u32>>,
1608
1609 has_position: bool,
1610 posting_tail_codec: PostingTailCodec,
1611 positions_layout: PositionsLayout,
1612
1613 index_cache: WeakLanceCache,
1614}
1615
1616#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1617enum PositionsLayout {
1618 None,
1619 LegacyPerDoc,
1620 SharedStream(PositionStreamCodec),
1621}
1622
1623impl std::fmt::Debug for PostingListReader {
1624 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1625 f.debug_struct("InvertedListReader")
1626 .field("offsets", &self.offsets)
1627 .field("max_scores", &self.max_scores)
1628 .finish()
1629 }
1630}
1631
1632impl DeepSizeOf for PostingListReader {
1633 fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
1634 self.offsets.deep_size_of_children(context)
1635 + self.max_scores.deep_size_of_children(context)
1636 + self.lengths.deep_size_of_children(context)
1637 }
1638}
1639
1640impl PostingListReader {
1641 pub(crate) async fn try_new(
1642 reader: Arc<dyn IndexReader>,
1643 index_cache: &LanceCache,
1644 ) -> Result<Self> {
1645 let positions_layout = if reader.schema().field(COMPRESSED_POSITION_COL).is_some() {
1646 PositionsLayout::SharedStream(parse_shared_position_codec(&reader.schema().metadata)?)
1647 } else if reader.schema().field(POSITION_COL).is_some() {
1648 PositionsLayout::LegacyPerDoc
1649 } else {
1650 PositionsLayout::None
1651 };
1652 let posting_tail_codec = parse_posting_tail_codec(&reader.schema().metadata)?;
1653 let has_position = positions_layout != PositionsLayout::None;
1654 let (offsets, max_scores, lengths) = if reader.schema().field(POSTING_COL).is_none() {
1655 let (offsets, max_scores) = Self::load_metadata(reader.schema())?;
1656 (Some(offsets), max_scores, None)
1657 } else {
1658 let metadata = reader
1659 .read_range(0..reader.num_rows(), Some(&[MAX_SCORE_COL, LENGTH_COL]))
1660 .await?;
1661 let max_scores = metadata[MAX_SCORE_COL]
1662 .as_primitive::<Float32Type>()
1663 .values()
1664 .to_vec();
1665 let lengths = metadata[LENGTH_COL]
1666 .as_primitive::<UInt32Type>()
1667 .values()
1668 .to_vec();
1669 (None, Some(max_scores), Some(lengths))
1670 };
1671
1672 Ok(Self {
1673 reader,
1674 offsets,
1675 max_scores,
1676 lengths,
1677 has_position,
1678 posting_tail_codec,
1679 positions_layout,
1680 index_cache: WeakLanceCache::from(index_cache),
1681 })
1682 }
1683
1684 fn load_metadata(
1687 schema: &lance_core::datatypes::Schema,
1688 ) -> Result<(Vec<usize>, Option<Vec<f32>>)> {
1689 let offsets = schema
1690 .metadata
1691 .get("offsets")
1692 .ok_or(Error::index("offsets not found in metadata".to_owned()))?;
1693 let offsets = serde_json::from_str(offsets)?;
1694
1695 let max_scores = schema
1696 .metadata
1697 .get("max_scores")
1698 .map(|max_scores| serde_json::from_str(max_scores))
1699 .transpose()?;
1700 Ok((offsets, max_scores))
1701 }
1702
1703 pub fn len(&self) -> usize {
1705 match self.offsets {
1706 Some(ref offsets) => offsets.len(),
1707 None => self.reader.num_rows(),
1708 }
1709 }
1710
1711 pub fn is_empty(&self) -> bool {
1712 self.len() == 0
1713 }
1714
1715 pub(crate) fn has_positions(&self) -> bool {
1716 self.has_position
1717 }
1718
1719 pub(crate) fn posting_tail_codec(&self) -> PostingTailCodec {
1720 self.posting_tail_codec
1721 }
1722
1723 pub(crate) fn posting_len(&self, token_id: u32) -> usize {
1724 let token_id = token_id as usize;
1725
1726 match self.offsets {
1727 Some(ref offsets) => {
1728 let next_offset = offsets
1729 .get(token_id + 1)
1730 .copied()
1731 .unwrap_or(self.reader.num_rows());
1732 next_offset - offsets[token_id]
1733 }
1734 None => {
1735 if let Some(lengths) = &self.lengths {
1736 lengths[token_id] as usize
1737 } else {
1738 panic!("posting list reader is not initialized")
1739 }
1740 }
1741 }
1742 }
1743
1744 pub(crate) async fn posting_batch(
1745 &self,
1746 token_id: u32,
1747 with_position: bool,
1748 ) -> Result<RecordBatch> {
1749 if self.offsets.is_some() {
1750 self.posting_batch_legacy(token_id, with_position).await
1751 } else {
1752 let token_id = token_id as usize;
1753 let columns = if with_position {
1754 match self.positions_layout {
1755 PositionsLayout::SharedStream(_) => {
1756 vec![
1757 POSTING_COL,
1758 COMPRESSED_POSITION_COL,
1759 POSITION_BLOCK_OFFSET_COL,
1760 ]
1761 }
1762 PositionsLayout::LegacyPerDoc => vec![POSTING_COL, POSITION_COL],
1763 PositionsLayout::None => vec![POSTING_COL],
1764 }
1765 } else {
1766 vec![POSTING_COL]
1767 };
1768 let batch = self
1769 .reader
1770 .read_range(token_id..token_id + 1, Some(&columns))
1771 .await?;
1772 Ok(batch)
1773 }
1774 }
1775
1776 async fn posting_batch_legacy(
1777 &self,
1778 token_id: u32,
1779 with_position: bool,
1780 ) -> Result<RecordBatch> {
1781 let mut columns = vec![ROW_ID, FREQUENCY_COL];
1782 if with_position {
1783 columns.push(POSITION_COL);
1784 }
1785
1786 let length = self.posting_len(token_id);
1787 let token_id = token_id as usize;
1788 let offset = self.offsets.as_ref().unwrap()[token_id];
1789 let batch = self
1790 .reader
1791 .read_range(offset..offset + length, Some(&columns))
1792 .await?;
1793 Ok(batch)
1794 }
1795
1796 #[instrument(level = "debug", skip(self, metrics))]
1797 pub(crate) async fn posting_list(
1798 &self,
1799 token_id: u32,
1800 is_phrase_query: bool,
1801 metrics: &dyn MetricsCollector,
1802 ) -> Result<PostingList> {
1803 let cache_key = PostingListKey { token_id };
1804 let mut posting = self
1805 .index_cache
1806 .get_or_insert_with_key(cache_key, || async move {
1807 metrics.record_part_load();
1808 info!(target: TRACE_IO_EVENTS, r#type=IO_TYPE_LOAD_SCALAR_PART, index_type="inverted", part_id=token_id);
1809 let batch = self.posting_batch(token_id, false).await?;
1810 self.posting_list_from_batch(&batch, token_id)
1811 })
1812 .await?
1813 .as_ref()
1814 .clone();
1815
1816 if is_phrase_query && !posting.has_position() {
1817 let positions = self.read_positions(token_id).await?;
1819 posting.set_positions(positions);
1820 }
1821
1822 Ok(posting)
1823 }
1824
1825 fn posting_list_from_batch_parts(
1826 batch: &RecordBatch,
1827 max_score: Option<f32>,
1828 length: Option<u32>,
1829 posting_tail_codec: PostingTailCodec,
1830 positions_layout: PositionsLayout,
1831 ) -> Result<PostingList> {
1832 let posting_list = PostingList::from_batch_with_tail_codec_and_positions_layout(
1833 batch,
1834 max_score,
1835 length,
1836 posting_tail_codec,
1837 positions_layout,
1838 )?;
1839 Ok(posting_list)
1840 }
1841
1842 pub(crate) fn posting_list_from_batch(
1843 &self,
1844 batch: &RecordBatch,
1845 token_id: u32,
1846 ) -> Result<PostingList> {
1847 Self::posting_list_from_batch_parts(
1848 batch,
1849 self.max_scores
1850 .as_ref()
1851 .map(|max_scores| max_scores[token_id as usize]),
1852 self.lengths
1853 .as_ref()
1854 .map(|lengths| lengths[token_id as usize]),
1855 self.posting_tail_codec,
1856 self.positions_layout,
1857 )
1858 }
1859
1860 fn build_prewarm_posting_lists(
1861 batch: RecordBatch,
1862 offsets: Option<Vec<usize>>,
1863 max_scores: Option<Vec<f32>>,
1864 lengths: Option<Vec<u32>>,
1865 posting_tail_codec: PostingTailCodec,
1866 positions_layout: PositionsLayout,
1867 ) -> Result<Vec<(u32, PostingList)>> {
1868 let token_count = if let Some(offsets) = offsets.as_ref() {
1869 offsets.len()
1870 } else if let Some(lengths) = lengths.as_ref() {
1871 lengths.len()
1872 } else {
1873 batch.num_rows()
1874 };
1875
1876 let mut posting_lists = Vec::with_capacity(token_count);
1877 for token_id in 0..token_count {
1878 let batch = if let Some(offsets) = offsets.as_ref() {
1879 let start = offsets[token_id];
1880 let end = if token_id + 1 < offsets.len() {
1881 offsets[token_id + 1]
1882 } else {
1883 batch.num_rows()
1884 };
1885 batch.slice(start, end - start)
1886 } else {
1887 batch.slice(token_id, 1)
1888 };
1889 let batch = batch.shrink_to_fit()?;
1890 let posting_list = Self::posting_list_from_batch_parts(
1891 &batch,
1892 max_scores.as_ref().map(|scores| scores[token_id]),
1893 lengths.as_ref().map(|lengths| lengths[token_id]),
1894 posting_tail_codec,
1895 positions_layout,
1896 )?;
1897 posting_lists.push((token_id as u32, posting_list));
1898 }
1899
1900 Ok(posting_lists)
1901 }
1902
1903 async fn prewarm_posting_lists(&self, with_position: bool) -> Result<()> {
1904 if with_position && !self.has_positions() {
1905 return Err(Error::invalid_input(
1906 "cannot prewarm positions for an inverted index that was built without positions; recreate the index with with_position=true".to_owned(),
1907 ));
1908 }
1909
1910 let read_batch_start = Instant::now();
1911 let batch = self.read_batch(with_position).await?;
1912 let read_batch_elapsed = read_batch_start.elapsed();
1913
1914 let legacy_layout = self.offsets.is_some();
1915 let offsets = self.offsets.clone();
1916 let max_scores = self.max_scores.clone();
1917 let lengths = self.lengths.clone();
1918 let posting_tail_codec = self.posting_tail_codec;
1919 let positions_layout = self.positions_layout;
1920 let populate_start = Instant::now();
1921 let posting_lists = spawn_blocking(move || {
1922 Self::build_prewarm_posting_lists(
1923 batch,
1924 offsets,
1925 max_scores,
1926 lengths,
1927 posting_tail_codec,
1928 positions_layout,
1929 )
1930 })
1931 .await
1932 .map_err(|err| {
1933 Error::internal(format!(
1934 "Failed to build prewarm posting lists in blocking task: {err}"
1935 ))
1936 })??;
1937 for (token_id, mut posting_list) in posting_lists {
1938 if with_position && let Some(positions) = posting_list.take_positions() {
1939 self.index_cache
1940 .insert_with_key(&PositionKey { token_id }, Arc::new(Positions(positions)))
1941 .await;
1942 }
1943 self.index_cache
1944 .insert_with_key(&PostingListKey { token_id }, Arc::new(posting_list))
1945 .await;
1946 }
1947 let populate_elapsed = populate_start.elapsed();
1948
1949 info!(
1950 legacy_layout,
1951 with_position,
1952 token_count = self.len(),
1953 read_batch_ms = read_batch_elapsed.as_secs_f64() * 1000.0,
1954 post_read_loop_ms = populate_elapsed.as_secs_f64() * 1000.0,
1955 "posting list prewarm timing"
1956 );
1957
1958 Ok(())
1959 }
1960
1961 pub(crate) async fn read_batch(&self, with_position: bool) -> Result<RecordBatch> {
1962 let columns = self.posting_columns(with_position);
1963 let batch = self
1964 .reader
1965 .read_range(0..self.reader.num_rows(), Some(&columns))
1966 .await?;
1967 Ok(batch)
1968 }
1969
1970 pub(crate) async fn read_all(
1971 &self,
1972 with_position: bool,
1973 ) -> Result<impl Iterator<Item = Result<PostingList>> + '_> {
1974 let batch = self.read_batch(with_position).await?;
1975 Ok((0..self.len()).map(move |i| {
1976 let token_id = i as u32;
1977 let range = self.posting_list_range(token_id);
1978 let batch = batch.slice(i, range.end - range.start);
1979 self.posting_list_from_batch(&batch, token_id)
1980 }))
1981 }
1982
1983 async fn read_positions(&self, token_id: u32) -> Result<CompressedPositionStorage> {
1984 let positions = self.index_cache.get_or_insert_with_key(PositionKey { token_id }, || async move {
1985 let positions = match self.positions_layout {
1986 PositionsLayout::None => {
1987 return Err(Error::invalid_input(
1988 "position is not found but required for phrase queries, try recreating the index with position".to_owned(),
1989 ));
1990 }
1991 PositionsLayout::LegacyPerDoc => {
1992 let batch = self
1993 .reader
1994 .read_range(self.posting_list_range(token_id), Some(&[POSITION_COL]))
1995 .await
1996 .map_err(|e| match e {
1997 Error::Schema { .. } => Error::invalid_input("position is not found but required for phrase queries, try recreating the index with position".to_owned()),
1998 e => e,
1999 })?;
2000 CompressedPositionStorage::LegacyPerDoc(
2001 batch[POSITION_COL].as_list::<i32>().value(0).as_list::<i32>().clone(),
2002 )
2003 }
2004 PositionsLayout::SharedStream(codec) => {
2005 let batch = self
2006 .reader
2007 .read_range(
2008 self.posting_list_range(token_id),
2009 Some(&[COMPRESSED_POSITION_COL, POSITION_BLOCK_OFFSET_COL]),
2010 )
2011 .await
2012 .map_err(|e| match e {
2013 Error::Schema { .. } => Error::invalid_input("position is not found but required for phrase queries, try recreating the index with position".to_owned()),
2014 e => e,
2015 })?;
2016 let bytes = bytes::Bytes::from(
2017 batch[COMPRESSED_POSITION_COL]
2018 .as_binary::<i64>()
2019 .value(0)
2020 .to_vec(),
2021 );
2022 let block_offsets = batch[POSITION_BLOCK_OFFSET_COL]
2023 .as_list::<i32>()
2024 .value(0)
2025 .as_primitive::<UInt32Type>()
2026 .values()
2027 .to_vec();
2028 CompressedPositionStorage::SharedStream(SharedPositionStream::new(
2029 codec,
2030 block_offsets,
2031 bytes,
2032 ))
2033 }
2034 };
2035 Result::Ok(Positions(positions))
2036 }).await?;
2037 Ok(positions.0.clone())
2038 }
2039
2040 fn posting_list_range(&self, token_id: u32) -> Range<usize> {
2041 match self.offsets {
2042 Some(ref offsets) => {
2043 let offset = offsets[token_id as usize];
2044 let posting_len = self.posting_len(token_id);
2045 offset..offset + posting_len
2046 }
2047 None => {
2048 let token_id = token_id as usize;
2049 token_id..token_id + 1
2050 }
2051 }
2052 }
2053
2054 fn posting_columns(&self, with_position: bool) -> Vec<&'static str> {
2055 let mut base_columns = match self.offsets {
2056 Some(_) => vec![ROW_ID, FREQUENCY_COL],
2057 None => vec![POSTING_COL],
2058 };
2059 if with_position {
2060 match self.positions_layout {
2061 PositionsLayout::None => {}
2062 PositionsLayout::LegacyPerDoc => base_columns.push(POSITION_COL),
2063 PositionsLayout::SharedStream(_) => {
2064 base_columns.push(COMPRESSED_POSITION_COL);
2065 base_columns.push(POSITION_BLOCK_OFFSET_COL);
2066 }
2067 }
2068 }
2069 base_columns
2070 }
2071}
2072
2073#[derive(Clone)]
2076pub struct Positions(pub(super) CompressedPositionStorage);
2077
2078impl DeepSizeOf for Positions {
2079 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2080 match &self.0 {
2081 CompressedPositionStorage::LegacyPerDoc(positions) => {
2082 positions.get_buffer_memory_size()
2083 }
2084 CompressedPositionStorage::SharedStream(stream) => stream.size(),
2085 }
2086 }
2087}
2088
/// Cache key identifying the posting list of one token.
#[derive(Clone, Debug)]
pub struct PostingListKey {
    /// Id of the token whose posting list is cached.
    pub token_id: u32,
}
2094
2095impl CacheKey for PostingListKey {
2096 type ValueType = PostingList;
2097
2098 fn key(&self) -> std::borrow::Cow<'_, str> {
2099 format!("postings-{}", self.token_id).into()
2100 }
2101
2102 fn type_name() -> &'static str {
2103 "PostingList"
2104 }
2105
2106 fn codec() -> Option<CacheCodec> {
2107 Some(CacheCodec::from_impl::<PostingList>())
2108 }
2109}
2110
/// Cache key identifying the positions of one token's posting list.
#[derive(Clone, Debug)]
pub struct PositionKey {
    /// Id of the token whose positions are cached.
    pub token_id: u32,
}
2115
2116impl CacheKey for PositionKey {
2117 type ValueType = Positions;
2118
2119 fn key(&self) -> std::borrow::Cow<'_, str> {
2120 format!("positions-{}", self.token_id).into()
2121 }
2122
2123 fn type_name() -> &'static str {
2124 "Position"
2125 }
2126
2127 fn codec() -> Option<CacheCodec> {
2128 Some(CacheCodec::from_impl::<Positions>())
2129 }
2130}
2131
2132#[derive(Debug, Clone, PartialEq)]
2133pub enum CompressedPositionStorage {
2134 LegacyPerDoc(ListArray),
2135 SharedStream(SharedPositionStream),
2136}
2137
2138impl DeepSizeOf for CompressedPositionStorage {
2139 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2140 match self {
2141 Self::LegacyPerDoc(positions) => positions.get_buffer_memory_size(),
2142 Self::SharedStream(stream) => stream.size(),
2143 }
2144 }
2145}
2146
2147#[derive(Debug, Clone, PartialEq, Eq, Default)]
2148pub struct SharedPositionStream {
2149 codec: PositionStreamCodec,
2150 block_offsets: Vec<u32>,
2151 bytes: bytes::Bytes,
2155}
2156
2157impl SharedPositionStream {
2158 pub fn new(codec: PositionStreamCodec, block_offsets: Vec<u32>, bytes: bytes::Bytes) -> Self {
2159 Self {
2160 codec,
2161 block_offsets,
2162 bytes,
2163 }
2164 }
2165
2166 pub fn codec(&self) -> PositionStreamCodec {
2167 self.codec
2168 }
2169
2170 pub fn block_count(&self) -> usize {
2171 self.block_offsets.len()
2172 }
2173
2174 pub fn block_range(&self, index: usize) -> Range<usize> {
2175 let start = self.block_offsets[index] as usize;
2176 let end = self
2177 .block_offsets
2178 .get(index + 1)
2179 .map(|offset| *offset as usize)
2180 .unwrap_or(self.bytes.len());
2181 start..end
2182 }
2183
2184 pub fn block(&self, index: usize) -> &[u8] {
2185 let range = self.block_range(index);
2186 &self.bytes[range]
2187 }
2188
2189 pub fn bytes(&self) -> &[u8] {
2190 &self.bytes
2191 }
2192
2193 pub fn block_offsets(&self) -> &[u32] {
2194 &self.block_offsets
2195 }
2196
2197 pub fn size(&self) -> usize {
2198 self.block_offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.len()
2199 }
2200}
2201
2202#[derive(Debug, Clone, DeepSizeOf)]
2203pub enum PostingList {
2204 Plain(PlainPostingList),
2205 Compressed(CompressedPostingList),
2206}
2207
2208impl PostingList {
2209 pub fn from_batch(
2210 batch: &RecordBatch,
2211 max_score: Option<f32>,
2212 length: Option<u32>,
2213 ) -> Result<Self> {
2214 let posting_tail_codec = parse_posting_tail_codec(batch.schema_ref().metadata())?;
2215 Self::from_batch_with_tail_codec(batch, max_score, length, posting_tail_codec)
2216 }
2217
2218 pub fn from_batch_with_tail_codec(
2219 batch: &RecordBatch,
2220 max_score: Option<f32>,
2221 length: Option<u32>,
2222 posting_tail_codec: PostingTailCodec,
2223 ) -> Result<Self> {
2224 let positions_layout = if batch.column_by_name(COMPRESSED_POSITION_COL).is_some() {
2225 PositionsLayout::SharedStream(parse_shared_position_codec(
2226 batch.schema_ref().metadata(),
2227 )?)
2228 } else if batch.column_by_name(POSITION_COL).is_some() {
2229 PositionsLayout::LegacyPerDoc
2230 } else {
2231 PositionsLayout::None
2232 };
2233 Self::from_batch_with_tail_codec_and_positions_layout(
2234 batch,
2235 max_score,
2236 length,
2237 posting_tail_codec,
2238 positions_layout,
2239 )
2240 }
2241
2242 fn from_batch_with_tail_codec_and_positions_layout(
2243 batch: &RecordBatch,
2244 max_score: Option<f32>,
2245 length: Option<u32>,
2246 posting_tail_codec: PostingTailCodec,
2247 positions_layout: PositionsLayout,
2248 ) -> Result<Self> {
2249 match batch.column_by_name(POSTING_COL) {
2250 Some(_) => {
2251 debug_assert!(max_score.is_some() && length.is_some());
2252 let shared_position_codec = match positions_layout {
2253 PositionsLayout::SharedStream(codec) => Some(codec),
2254 _ => None,
2255 };
2256 let posting = CompressedPostingList::from_batch(
2257 batch,
2258 max_score.unwrap(),
2259 length.unwrap(),
2260 posting_tail_codec,
2261 shared_position_codec,
2262 );
2263 Ok(Self::Compressed(posting))
2264 }
2265 None => {
2266 let posting = PlainPostingList::from_batch(batch, max_score);
2267 Ok(Self::Plain(posting))
2268 }
2269 }
2270 }
2271
2272 pub fn iter(&self) -> PostingListIterator<'_> {
2273 PostingListIterator::new(self)
2274 }
2275
2276 pub fn has_position(&self) -> bool {
2277 match self {
2278 Self::Plain(posting) => posting.positions.is_some(),
2279 Self::Compressed(posting) => posting.positions.is_some(),
2280 }
2281 }
2282
2283 pub fn set_positions(&mut self, positions: CompressedPositionStorage) {
2284 match self {
2285 Self::Plain(posting) => match positions {
2286 CompressedPositionStorage::LegacyPerDoc(positions) => {
2287 posting.positions = Some(positions)
2288 }
2289 CompressedPositionStorage::SharedStream(_) => {
2290 unreachable!("shared position stream is not supported for plain postings")
2291 }
2292 },
2293 Self::Compressed(posting) => {
2294 posting.positions = Some(positions);
2295 }
2296 }
2297 }
2298
2299 pub fn take_positions(&mut self) -> Option<CompressedPositionStorage> {
2300 match self {
2301 Self::Plain(posting) => posting
2302 .positions
2303 .take()
2304 .map(CompressedPositionStorage::LegacyPerDoc),
2305 Self::Compressed(posting) => posting.positions.take(),
2306 }
2307 }
2308
2309 pub fn max_score(&self) -> Option<f32> {
2310 match self {
2311 Self::Plain(posting) => posting.max_score,
2312 Self::Compressed(posting) => Some(posting.max_score),
2313 }
2314 }
2315
2316 pub fn len(&self) -> usize {
2317 match self {
2318 Self::Plain(posting) => posting.len(),
2319 Self::Compressed(posting) => posting.length as usize,
2320 }
2321 }
2322
2323 pub fn is_empty(&self) -> bool {
2324 self.len() == 0
2325 }
2326
2327 pub fn into_builder(self, docs: &DocSet) -> PostingListBuilder {
2328 let posting_tail_codec = match &self {
2329 Self::Plain(_) => PostingTailCodec::Fixed32,
2330 Self::Compressed(posting) => posting.posting_tail_codec,
2331 };
2332 let mut builder = PostingListBuilder::new_with_posting_tail_codec(
2333 self.has_position(),
2334 posting_tail_codec,
2335 );
2336 match self {
2337 Self::Plain(posting) => {
2339 struct Item {
2343 doc_id: u32,
2344 positions: PositionRecorder,
2345 }
2346 let doc_ids = docs
2347 .row_ids
2348 .iter()
2349 .enumerate()
2350 .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
2351 .collect::<HashMap<_, _>>();
2352 let mut items = Vec::with_capacity(posting.len());
2353 for (row_id, freq, positions) in posting.iter() {
2354 let freq = freq as u32;
2355 let positions = match positions {
2356 Some(positions) => {
2357 PositionRecorder::Position(positions.collect::<Vec<_>>().into())
2358 }
2359 None => PositionRecorder::Count(freq),
2360 };
2361 items.push(Item {
2362 doc_id: doc_ids[&row_id],
2363 positions,
2364 });
2365 }
2366 items.sort_unstable_by_key(|item| item.doc_id);
2367 for item in items {
2368 builder.add(item.doc_id, item.positions);
2369 }
2370 }
2371 Self::Compressed(posting) => {
2372 posting.iter().for_each(|(doc_id, freq, positions)| {
2373 let positions = match positions {
2374 Some(positions) => {
2375 PositionRecorder::Position(positions.collect::<Vec<_>>().into())
2376 }
2377 None => PositionRecorder::Count(freq),
2378 };
2379 builder.add(doc_id, positions);
2380 });
2381 }
2382 }
2383 builder
2384 }
2385}
2386
2387#[derive(Debug, PartialEq, Clone)]
2388pub struct PlainPostingList {
2389 pub row_ids: ScalarBuffer<u64>,
2390 pub frequencies: ScalarBuffer<f32>,
2391 pub max_score: Option<f32>,
2392 pub positions: Option<ListArray>, }
2394
2395impl DeepSizeOf for PlainPostingList {
2396 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2397 self.row_ids.len() * std::mem::size_of::<u64>()
2398 + self.frequencies.len() * std::mem::size_of::<u32>()
2399 + self
2400 .positions
2401 .as_ref()
2402 .map(Array::get_buffer_memory_size)
2403 .unwrap_or(0)
2404 }
2405}
2406
2407impl PlainPostingList {
2408 pub fn new(
2409 row_ids: ScalarBuffer<u64>,
2410 frequencies: ScalarBuffer<f32>,
2411 max_score: Option<f32>,
2412 positions: Option<ListArray>,
2413 ) -> Self {
2414 Self {
2415 row_ids,
2416 frequencies,
2417 max_score,
2418 positions,
2419 }
2420 }
2421
2422 pub fn from_batch(batch: &RecordBatch, max_score: Option<f32>) -> Self {
2423 let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values().clone();
2424 let frequencies = batch[FREQUENCY_COL]
2425 .as_primitive::<Float32Type>()
2426 .values()
2427 .clone();
2428 let positions = batch
2429 .column_by_name(POSITION_COL)
2430 .map(|col| col.as_list::<i32>().clone());
2431
2432 Self::new(row_ids, frequencies, max_score, positions)
2433 }
2434
2435 pub fn len(&self) -> usize {
2436 self.row_ids.len()
2437 }
2438
2439 pub fn is_empty(&self) -> bool {
2440 self.len() == 0
2441 }
2442
2443 pub fn iter(&self) -> PlainPostingListIterator<'_> {
2444 Box::new(
2445 self.row_ids
2446 .iter()
2447 .zip(self.frequencies.iter())
2448 .enumerate()
2449 .map(|(idx, (doc_id, freq))| {
2450 (
2451 *doc_id,
2452 *freq,
2453 self.positions.as_ref().map(|p| {
2454 let start = p.value_offsets()[idx] as usize;
2455 let end = p.value_offsets()[idx + 1] as usize;
2456 Box::new(
2457 p.values().as_primitive::<Int32Type>().values()[start..end]
2458 .iter()
2459 .map(|pos| *pos as u32),
2460 ) as _
2461 }),
2462 )
2463 }),
2464 )
2465 }
2466
2467 #[inline]
2468 pub fn doc(&self, i: usize) -> LocatedDocInfo {
2469 LocatedDocInfo::new(self.row_ids[i], self.frequencies[i])
2470 }
2471
2472 pub fn positions(&self, index: usize) -> Option<Arc<dyn Array>> {
2473 self.positions
2474 .as_ref()
2475 .map(|positions| positions.value(index))
2476 }
2477
2478 pub fn max_score(&self) -> Option<f32> {
2479 self.max_score
2480 }
2481
2482 pub fn row_id(&self, i: usize) -> u64 {
2483 self.row_ids[i]
2484 }
2485}
2486
2487#[derive(Debug, PartialEq, Clone)]
2488pub struct CompressedPostingList {
2489 pub max_score: f32,
2490 pub length: u32,
2491 pub blocks: LargeBinaryArray,
2494 pub posting_tail_codec: PostingTailCodec,
2495 pub positions: Option<CompressedPositionStorage>,
2496}
2497
2498impl DeepSizeOf for CompressedPostingList {
2499 fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
2500 self.blocks.get_buffer_memory_size()
2501 + self
2502 .positions
2503 .as_ref()
2504 .map(|positions| match positions {
2505 CompressedPositionStorage::LegacyPerDoc(positions) => {
2506 positions.get_buffer_memory_size()
2507 }
2508 CompressedPositionStorage::SharedStream(stream) => stream.size(),
2509 })
2510 .unwrap_or(0)
2511 }
2512}
2513
2514impl CompressedPostingList {
2515 pub fn new(
2516 blocks: LargeBinaryArray,
2517 max_score: f32,
2518 length: u32,
2519 posting_tail_codec: PostingTailCodec,
2520 positions: Option<CompressedPositionStorage>,
2521 ) -> Self {
2522 Self {
2523 max_score,
2524 length,
2525 blocks,
2526 posting_tail_codec,
2527 positions,
2528 }
2529 }
2530
2531 pub fn from_batch(
2532 batch: &RecordBatch,
2533 max_score: f32,
2534 length: u32,
2535 posting_tail_codec: PostingTailCodec,
2536 shared_position_codec: Option<PositionStreamCodec>,
2537 ) -> Self {
2538 debug_assert_eq!(batch.num_rows(), 1);
2539 let blocks = batch[POSTING_COL]
2540 .as_list::<i32>()
2541 .value(0)
2542 .as_binary::<i64>()
2543 .clone();
2544 let positions = if let Some(col) = batch.column_by_name(COMPRESSED_POSITION_COL) {
2545 let bytes = bytes::Bytes::from(col.as_binary::<i64>().value(0).to_vec());
2546 let block_offsets = batch[POSITION_BLOCK_OFFSET_COL]
2547 .as_list::<i32>()
2548 .value(0)
2549 .as_primitive::<UInt32Type>()
2550 .values()
2551 .to_vec();
2552 let codec = shared_position_codec.unwrap_or_else(|| {
2553 parse_shared_position_codec(batch.schema_ref().metadata())
2554 .expect("shared position stream codec metadata should be valid")
2555 });
2556 Some(CompressedPositionStorage::SharedStream(
2557 SharedPositionStream::new(codec, block_offsets, bytes),
2558 ))
2559 } else {
2560 batch.column_by_name(POSITION_COL).map(|col| {
2561 CompressedPositionStorage::LegacyPerDoc(
2562 col.as_list::<i32>().value(0).as_list::<i32>().clone(),
2563 )
2564 })
2565 };
2566
2567 Self {
2568 max_score,
2569 length,
2570 blocks,
2571 posting_tail_codec,
2572 positions,
2573 }
2574 }
2575
2576 pub fn iter(&self) -> CompressedPostingListIterator {
2577 CompressedPostingListIterator::new(
2578 self.length as usize,
2579 self.blocks.clone(),
2580 self.posting_tail_codec,
2581 self.positions.clone(),
2582 )
2583 }
2584
2585 pub fn block_max_score(&self, block_idx: usize) -> f32 {
2586 let block = self.blocks.value(block_idx);
2587 block[0..4].try_into().map(f32::from_le_bytes).unwrap()
2588 }
2589
2590 pub fn block_least_doc_id(&self, block_idx: usize) -> u32 {
2591 let block = self.blocks.value(block_idx);
2592 let remainder = self.length as usize % BLOCK_SIZE;
2593 let is_remainder_block = remainder > 0 && block_idx + 1 == self.blocks.len();
2594 if is_remainder_block {
2595 super::encoding::read_posting_tail_first_doc(block, self.posting_tail_codec)
2596 } else {
2597 block[4..8].try_into().map(u32::from_le_bytes).unwrap()
2598 }
2599 }
2600}
2601
/// Encoded posting blocks stored back to back in a single byte buffer.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct EncodedBlocks {
    /// Start offset of each block within `bytes`.
    offsets: Vec<u32>,
    /// Concatenated encoded blocks.
    bytes: Vec<u8>,
}
2607
2608impl EncodedBlocks {
2609 fn len(&self) -> usize {
2610 self.offsets.len()
2611 }
2612
2613 fn size(&self) -> usize {
2614 self.offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.capacity()
2615 }
2616
2617 fn push_full_block(&mut self, doc_ids: &[u32], frequencies: &[u32]) -> Result<usize> {
2618 let start = self.bytes.len();
2619 self.offsets.push(start as u32);
2620 super::encoding::encode_full_posting_block_into(doc_ids, frequencies, &mut self.bytes)?;
2621 Ok(self.bytes.len() - start)
2622 }
2623
2624 fn block(&self, index: usize) -> &[u8] {
2625 let (start, end) = self.block_range(index);
2626 &self.bytes[start..end]
2627 }
2628
2629 fn block_range(&self, index: usize) -> (usize, usize) {
2630 let start = self.offsets[index] as usize;
2631 let end = self
2632 .offsets
2633 .get(index + 1)
2634 .map(|offset| *offset as usize)
2635 .unwrap_or(self.bytes.len());
2636 (start, end)
2637 }
2638
2639 fn set_block_score(&mut self, index: usize, score: f32) {
2640 let (start, _) = self.block_range(index);
2641 self.bytes[start..start + 4].copy_from_slice(&score.to_le_bytes());
2642 }
2643
2644 fn append_remainder_block_with_codec(
2645 &mut self,
2646 doc_ids: &[u32],
2647 frequencies: &[u32],
2648 codec: PostingTailCodec,
2649 ) -> Result<()> {
2650 self.offsets.push(self.bytes.len() as u32);
2651 super::encoding::encode_remainder_posting_block_into(
2652 doc_ids,
2653 frequencies,
2654 codec,
2655 &mut self.bytes,
2656 )
2657 }
2658
2659 fn into_array(mut self) -> LargeBinaryArray {
2660 let mut offsets = Vec::with_capacity(self.offsets.len() + 1);
2661 offsets.extend(self.offsets.into_iter().map(i64::from));
2662 offsets.push(self.bytes.len() as i64);
2663 LargeBinaryArray::new(
2664 OffsetBuffer::new(ScalarBuffer::from(offsets)),
2665 Buffer::from_vec(std::mem::take(&mut self.bytes)),
2666 None,
2667 )
2668 }
2669
2670 fn iter(&self) -> impl Iterator<Item = &[u8]> {
2671 (0..self.len()).map(|index| self.block(index))
2672 }
2673}
2674
/// Encoded position blocks stored back to back in a single byte buffer.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
struct EncodedPositionBlocks {
    /// Start offset of each block within `bytes`.
    offsets: Vec<u32>,
    /// Concatenated encoded position blocks.
    bytes: Vec<u8>,
}
2680
2681impl EncodedPositionBlocks {
2682 fn size(&self) -> usize {
2683 self.offsets.capacity() * std::mem::size_of::<u32>() + self.bytes.capacity()
2684 }
2685
2686 fn block(&self, index: usize) -> &[u8] {
2687 let start = self.offsets[index] as usize;
2688 let end = self
2689 .offsets
2690 .get(index + 1)
2691 .map(|offset| *offset as usize)
2692 .unwrap_or(self.bytes.len());
2693 &self.bytes[start..end]
2694 }
2695
2696 fn push_encoded_block(&mut self, block: &[u8]) -> usize {
2697 let start = self.bytes.len();
2698 self.offsets.push(start as u32);
2699 self.bytes.extend_from_slice(block);
2700 self.bytes.len() - start
2701 }
2702
2703 fn into_stream(self) -> SharedPositionStream {
2704 SharedPositionStream::new(
2705 PositionStreamCodec::PackedDelta,
2706 self.offsets,
2707 bytes::Bytes::from(self.bytes),
2708 )
2709 }
2710}
2711
2712#[derive(Debug)]
2713pub struct PostingListBuilder {
2714 with_positions: bool,
2715 posting_tail_codec: PostingTailCodec,
2716 encoded_blocks: Option<Box<EncodedBlocks>>,
2717 encoded_position_blocks: Option<Box<EncodedPositionBlocks>>,
2718 tail_entries: Vec<RawDocInfo>,
2719 tail_positions: PositionBlockBuilder,
2720 open_doc_id: Option<u32>,
2721 open_doc_frequency: u32,
2722 open_doc_last_position: Option<u32>,
2723 memory_size_bytes: u32,
2724 len: u32,
2725}
2726
2727pub(super) struct PostingListBatchBuilder {
2728 schema: SchemaRef,
2729 postings: ListBuilder<LargeBinaryBuilder>,
2730 max_scores: Float32Builder,
2731 lengths: UInt32Builder,
2732 positions: BatchPositionsBuilder,
2733 len: usize,
2734}
2735
2736enum BatchPositionsBuilder {
2737 None,
2738 Legacy(ListBuilder<ListBuilder<LargeBinaryBuilder>>),
2739 Shared {
2740 bytes: LargeBinaryBuilder,
2741 block_offsets: ListBuilder<UInt32Builder>,
2742 },
2743}
2744
2745struct PostingListParts<'a> {
2746 with_positions: bool,
2747 posting_tail_codec: PostingTailCodec,
2748 length: usize,
2749 encoded_blocks: EncodedBlocks,
2750 encoded_position_blocks: EncodedPositionBlocks,
2751 tail_entries: &'a [RawDocInfo],
2752 tail_position_block: Option<Vec<u8>>,
2753}
2754
2755impl PostingListBatchBuilder {
2756 pub fn new(
2757 schema: SchemaRef,
2758 with_positions: bool,
2759 format_version: InvertedListFormatVersion,
2760 capacity: usize,
2761 ) -> Self {
2762 let positions = if !with_positions {
2763 BatchPositionsBuilder::None
2764 } else if format_version.uses_shared_position_stream() {
2765 BatchPositionsBuilder::Shared {
2766 bytes: LargeBinaryBuilder::with_capacity(capacity, 0),
2767 block_offsets: ListBuilder::with_capacity(UInt32Builder::new(), capacity),
2768 }
2769 } else {
2770 BatchPositionsBuilder::Legacy(ListBuilder::with_capacity(
2771 ListBuilder::new(LargeBinaryBuilder::new()),
2772 capacity,
2773 ))
2774 };
2775 Self {
2776 schema,
2777 postings: ListBuilder::with_capacity(LargeBinaryBuilder::new(), capacity),
2778 max_scores: Float32Builder::with_capacity(capacity),
2779 lengths: UInt32Builder::with_capacity(capacity),
2780 positions,
2781 len: 0,
2782 }
2783 }
2784
2785 pub fn len(&self) -> usize {
2786 self.len
2787 }
2788
2789 pub fn is_empty(&self) -> bool {
2790 self.len == 0
2791 }
2792
2793 fn append(
2794 &mut self,
2795 compressed: LargeBinaryArray,
2796 max_score: f32,
2797 length: u32,
2798 positions: Option<&CompressedPositionStorage>,
2799 ) -> Result<()> {
2800 {
2801 let values = self.postings.values();
2802 for index in 0..compressed.len() {
2803 values.append_value(compressed.value(index));
2804 }
2805 }
2806 self.postings.append(true);
2807 self.max_scores.append_value(max_score);
2808 self.lengths.append_value(length);
2809
2810 match &mut self.positions {
2811 BatchPositionsBuilder::None => {}
2812 BatchPositionsBuilder::Shared {
2813 bytes,
2814 block_offsets,
2815 } => {
2816 let positions = positions.ok_or_else(|| {
2817 Error::index(format!(
2818 "positions builder missing position data for posting length {}",
2819 length
2820 ))
2821 })?;
2822 let CompressedPositionStorage::SharedStream(positions) = positions else {
2823 return Err(Error::index(
2824 "shared positions builder received legacy positions".to_owned(),
2825 ));
2826 };
2827 bytes.append_value(positions.bytes());
2828 let offsets_builder = block_offsets.values();
2829 for &offset in positions.block_offsets() {
2830 offsets_builder.append_value(offset);
2831 }
2832 block_offsets.append(true);
2833 }
2834 BatchPositionsBuilder::Legacy(position_lists) => {
2835 let positions = positions.ok_or_else(|| {
2836 Error::index(format!(
2837 "positions builder missing position data for posting length {}",
2838 length
2839 ))
2840 })?;
2841 let CompressedPositionStorage::LegacyPerDoc(positions) = positions else {
2842 return Err(Error::index(
2843 "legacy positions builder received shared position stream".to_owned(),
2844 ));
2845 };
2846 let docs_builder = position_lists.values();
2847 for doc_idx in 0..positions.len() {
2848 let doc_positions = positions.value(doc_idx);
2849 let compressed_positions = doc_positions.as_binary::<i64>();
2850 for block_idx in 0..compressed_positions.len() {
2851 docs_builder
2852 .values()
2853 .append_value(compressed_positions.value(block_idx));
2854 }
2855 docs_builder.append(true);
2856 }
2857 position_lists.append(true);
2858 }
2859 }
2860
2861 self.len += 1;
2862 Ok(())
2863 }
2864
2865 pub fn finish(&mut self) -> Result<RecordBatch> {
2866 let mut columns = vec![
2867 Arc::new(self.postings.finish()) as ArrayRef,
2868 Arc::new(self.max_scores.finish()) as ArrayRef,
2869 Arc::new(self.lengths.finish()) as ArrayRef,
2870 ];
2871 match &mut self.positions {
2872 BatchPositionsBuilder::None => {}
2873 BatchPositionsBuilder::Legacy(position_lists) => {
2874 columns.push(Arc::new(position_lists.finish()) as ArrayRef);
2875 }
2876 BatchPositionsBuilder::Shared {
2877 bytes,
2878 block_offsets,
2879 } => {
2880 columns.push(Arc::new(bytes.finish()) as ArrayRef);
2881 columns.push(Arc::new(block_offsets.finish()) as ArrayRef);
2882 }
2883 }
2884 self.len = 0;
2885 RecordBatch::try_new(self.schema.clone(), columns).map_err(Error::from)
2886 }
2887}
2888
2889impl PostingListBuilder {
2890 pub fn size(&self) -> u64 {
2891 self.memory_size_bytes as u64
2892 }
2893
2894 pub fn has_positions(&self) -> bool {
2895 self.with_positions
2896 }
2897
2898 pub fn new(with_position: bool) -> Self {
2899 Self::new_with_posting_tail_codec(
2900 with_position,
2901 current_fts_format_version().posting_tail_codec(),
2902 )
2903 }
2904
2905 pub fn new_with_posting_tail_codec(
2906 with_position: bool,
2907 posting_tail_codec: PostingTailCodec,
2908 ) -> Self {
2909 Self {
2910 with_positions: with_position,
2911 posting_tail_codec,
2912 encoded_blocks: None,
2913 encoded_position_blocks: None,
2914 tail_entries: Vec::new(),
2915 tail_positions: PositionBlockBuilder::default(),
2916 open_doc_id: None,
2917 open_doc_frequency: 0,
2918 open_doc_last_position: None,
2919 len: 0,
2920 memory_size_bytes: 0,
2921 }
2922 }
2923
2924 pub fn len(&self) -> usize {
2925 self.len as usize
2926 }
2927
2928 pub fn is_empty(&self) -> bool {
2929 self.len == 0
2930 }
2931
2932 pub fn iter(&self) -> std::vec::IntoIter<(u32, u32, Option<Vec<u32>>)> {
2933 self.collect_entries().into_iter()
2934 }
2935
2936 pub fn for_each_entry<E>(
2937 &self,
2938 mut visit: impl FnMut(u32, u32, Option<Vec<u32>>) -> std::result::Result<(), E>,
2939 ) -> std::result::Result<(), E> {
2940 let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
2941 let mut frequencies = Vec::with_capacity(BLOCK_SIZE);
2942 let mut decoded_positions = Vec::new();
2943 let mut position_block_index = 0usize;
2944
2945 if let Some(encoded_blocks) = self.encoded_blocks.as_deref() {
2946 for block in encoded_blocks.iter() {
2947 doc_ids.clear();
2948 frequencies.clear();
2949 super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies);
2950 decoded_positions.clear();
2951 if self.with_positions {
2952 let position_blocks = self
2953 .encoded_position_blocks
2954 .as_deref()
2955 .expect("positions must exist for posting list");
2956 super::encoding::decode_position_stream_block(
2957 position_blocks.block(position_block_index),
2958 &frequencies,
2959 PositionStreamCodec::PackedDelta,
2960 &mut decoded_positions,
2961 )
2962 .expect("position stream decoding should succeed");
2963 position_block_index += 1;
2964 }
2965 let mut offset = 0usize;
2966 for (doc_id, frequency) in doc_ids.iter().copied().zip(frequencies.iter().copied())
2967 {
2968 let positions = self.with_positions.then(|| {
2969 let end = offset + frequency as usize;
2970 let doc_positions = decoded_positions[offset..end].to_vec();
2971 offset = end;
2972 doc_positions
2973 });
2974 visit(doc_id, frequency, positions)?;
2975 }
2976 }
2977 }
2978
2979 let mut decoded_tail_positions = Vec::new();
2980 if self.with_positions && !self.tail_entries.is_empty() {
2981 let tail_frequencies = self
2982 .tail_entries
2983 .iter()
2984 .map(|entry| entry.frequency)
2985 .collect::<Vec<_>>();
2986 self.tail_positions
2987 .decode_into(tail_frequencies.as_slice(), &mut decoded_tail_positions)
2988 .expect("tail position stream decoding should succeed");
2989 }
2990 let mut tail_offset = 0usize;
2991 for entry in &self.tail_entries {
2992 let positions = self.with_positions.then(|| {
2993 let end = tail_offset + entry.frequency as usize;
2994 let doc_positions = decoded_tail_positions[tail_offset..end].to_vec();
2995 tail_offset = end;
2996 doc_positions
2997 });
2998 visit(entry.doc_id, entry.frequency, positions)?;
2999 }
3000
3001 Ok(())
3002 }
3003
3004 pub fn add(&mut self, doc_id: u32, term_positions: PositionRecorder) {
3005 debug_assert!(
3006 self.open_doc_id.is_none(),
3007 "cannot add closed doc while a positions doc is still open"
3008 );
3009 let tail_entries_capacity_before = self.tail_entries.capacity();
3010 self.tail_entries
3011 .push(RawDocInfo::new(doc_id, term_positions.len()));
3012 let tail_entries_capacity_after = self.tail_entries.capacity();
3013 if tail_entries_capacity_after > tail_entries_capacity_before {
3014 self.add_memory_bytes(
3015 (tail_entries_capacity_after - tail_entries_capacity_before)
3016 * std::mem::size_of::<RawDocInfo>(),
3017 );
3018 }
3019 if let PositionRecorder::Position(positions_in_doc) = term_positions {
3020 debug_assert!(self.with_positions);
3021 let old_size = self.tail_positions.size();
3022 self.tail_positions
3023 .append_doc_positions(positions_in_doc.as_slice())
3024 .expect("position stream encoding should succeed");
3025 self.adjust_tail_positions_size(old_size);
3026 }
3027 self.len += 1;
3028
3029 if self.tail_entries.len() == BLOCK_SIZE {
3030 self.flush_tail_block()
3031 .expect("posting list block compression should succeed");
3032 }
3033 }
3034
3035 pub fn add_occurrence(&mut self, doc_id: u32, position: u32) -> Result<bool> {
3036 if !self.with_positions {
3037 return Err(Error::index(
3038 "cannot append streamed positions to a posting list without positions".to_owned(),
3039 ));
3040 }
3041
3042 match self.open_doc_id {
3043 Some(open_doc_id) if open_doc_id == doc_id => {
3044 let old_size = self.tail_positions.size();
3045 self.tail_positions
3046 .append_position(position, self.open_doc_last_position)?;
3047 self.adjust_tail_positions_size(old_size);
3048 self.open_doc_frequency += 1;
3049 self.open_doc_last_position = Some(position);
3050 Ok(false)
3051 }
3052 Some(open_doc_id) => Err(Error::index(format!(
3053 "posting list received doc {} before finishing open doc {}",
3054 doc_id, open_doc_id
3055 ))),
3056 None => {
3057 let old_size = self.tail_positions.size();
3058 self.tail_positions.append_position(position, None)?;
3059 self.adjust_tail_positions_size(old_size);
3060 self.open_doc_id = Some(doc_id);
3061 self.open_doc_frequency = 1;
3062 self.open_doc_last_position = Some(position);
3063 self.len += 1;
3064 Ok(true)
3065 }
3066 }
3067 }
3068
3069 pub fn finish_open_doc(&mut self, doc_id: u32) -> Result<()> {
3070 if !self.with_positions {
3071 return Ok(());
3072 }
3073 match self.open_doc_id {
3074 Some(open_doc_id) if open_doc_id == doc_id => {
3075 let tail_entries_capacity_before = self.tail_entries.capacity();
3076 self.tail_entries
3077 .push(RawDocInfo::new(doc_id, self.open_doc_frequency));
3078 let tail_entries_capacity_after = self.tail_entries.capacity();
3079 if tail_entries_capacity_after > tail_entries_capacity_before {
3080 self.add_memory_bytes(
3081 (tail_entries_capacity_after - tail_entries_capacity_before)
3082 * std::mem::size_of::<RawDocInfo>(),
3083 );
3084 }
3085 self.open_doc_id = None;
3086 self.open_doc_frequency = 0;
3087 self.open_doc_last_position = None;
3088 if self.tail_entries.len() == BLOCK_SIZE {
3089 self.flush_tail_block()?;
3090 }
3091 Ok(())
3092 }
3093 Some(open_doc_id) => Err(Error::index(format!(
3094 "attempted to finish doc {} while doc {} is still open",
3095 doc_id, open_doc_id
3096 ))),
3097 None => Ok(()),
3098 }
3099 }
3100
3101 fn collect_entries(&self) -> Vec<(u32, u32, Option<Vec<u32>>)> {
3102 let mut entries = Vec::with_capacity(self.len());
3103 self.for_each_entry(|doc_id, frequency, positions| {
3104 entries.push((doc_id, frequency, positions));
3105 Ok::<(), ()>(())
3106 })
3107 .expect("collecting posting list entries should not fail");
3108 entries
3109 }
3110
3111 fn encoded_blocks_mut(&mut self) -> &mut EncodedBlocks {
3112 if self.encoded_blocks.is_none() {
3113 self.encoded_blocks = Some(Box::default());
3114 self.add_memory_bytes(std::mem::size_of::<EncodedBlocks>());
3115 }
3116 self.encoded_blocks
3117 .as_deref_mut()
3118 .expect("encoded blocks must exist")
3119 }
3120
3121 fn encoded_position_blocks_mut(&mut self) -> &mut EncodedPositionBlocks {
3122 if self.encoded_position_blocks.is_none() {
3123 self.encoded_position_blocks = Some(Box::default());
3124 self.add_memory_bytes(std::mem::size_of::<EncodedPositionBlocks>());
3125 }
3126 self.encoded_position_blocks
3127 .as_deref_mut()
3128 .expect("encoded position blocks must exist")
3129 }
3130
3131 fn flush_tail_block(&mut self) -> Result<()> {
3132 if self.tail_entries.is_empty() {
3133 return Ok(());
3134 }
3135 debug_assert!(
3136 self.open_doc_id.is_none(),
3137 "cannot flush a posting block while a document is still open"
3138 );
3139 debug_assert_eq!(self.tail_entries.len(), BLOCK_SIZE);
3140 let mut doc_ids = [0u32; BLOCK_SIZE];
3141 let mut frequencies = [0u32; BLOCK_SIZE];
3142 for (index, entry) in self.tail_entries.iter().enumerate() {
3143 doc_ids[index] = entry.doc_id;
3144 frequencies[index] = entry.frequency;
3145 }
3146 let encoded_blocks_size_before = self
3147 .encoded_blocks
3148 .as_ref()
3149 .map(|encoded_blocks| encoded_blocks.size())
3150 .unwrap_or(0usize);
3151 self.encoded_blocks_mut()
3152 .push_full_block(&doc_ids, &frequencies)?;
3153 let encoded_blocks_size_after = self
3154 .encoded_blocks
3155 .as_ref()
3156 .map(|encoded_blocks| encoded_blocks.size())
3157 .unwrap_or(0usize);
3158 if encoded_blocks_size_after > encoded_blocks_size_before {
3159 self.add_memory_bytes(encoded_blocks_size_after - encoded_blocks_size_before);
3160 }
3161 if self.with_positions {
3162 let encoded_positions_size_before = self
3163 .encoded_position_blocks
3164 .as_ref()
3165 .map(|encoded| encoded.size())
3166 .unwrap_or(0usize);
3167 let released_tail_positions_bytes = self.tail_positions.size();
3168 let tail_position_block = std::mem::take(&mut self.tail_positions).finish();
3169 self.encoded_position_blocks_mut()
3170 .push_encoded_block(tail_position_block.as_slice());
3171 let encoded_positions_size_after = self
3172 .encoded_position_blocks
3173 .as_ref()
3174 .map(|encoded| encoded.size())
3175 .unwrap_or(0usize);
3176 if released_tail_positions_bytes > 0 {
3177 self.subtract_memory_bytes(released_tail_positions_bytes);
3178 }
3179 if encoded_positions_size_after > encoded_positions_size_before {
3180 self.add_memory_bytes(encoded_positions_size_after - encoded_positions_size_before);
3181 }
3182 }
3183 self.tail_entries.clear();
3184 Ok(())
3185 }
3186
3187 fn adjust_tail_positions_size(&mut self, old_size: usize) {
3188 let new_size = self.tail_positions.size();
3189 if new_size > old_size {
3190 self.add_memory_bytes(new_size - old_size);
3191 } else if old_size > new_size {
3192 self.subtract_memory_bytes(old_size - new_size);
3193 }
3194 }
3195
3196 fn add_memory_bytes(&mut self, bytes: usize) {
3197 self.memory_size_bytes = self
3198 .memory_size_bytes
3199 .checked_add(
3200 u32::try_from(bytes).expect("posting list memory size delta overflowed u32"),
3201 )
3202 .expect("posting list memory size overflowed u32");
3203 }
3204
3205 fn subtract_memory_bytes(&mut self, bytes: usize) {
3206 self.memory_size_bytes = self
3207 .memory_size_bytes
3208 .checked_sub(
3209 u32::try_from(bytes).expect("posting list memory size delta overflowed u32"),
3210 )
3211 .expect("posting list memory size underflowed u32");
3212 }
3213
3214 fn build_position_columns(
3215 positions: Option<CompressedPositionStorage>,
3216 ) -> Result<Vec<ArrayRef>> {
3217 let Some(positions) = positions else {
3218 return Ok(Vec::new());
3219 };
3220 match positions {
3221 CompressedPositionStorage::LegacyPerDoc(positions) => {
3222 Ok(vec![Arc::new(ListArray::try_new(
3223 Arc::new(Field::new("item", positions.data_type().clone(), true)),
3224 OffsetBuffer::new(ScalarBuffer::from(vec![0_i32, positions.len() as i32])),
3225 Arc::new(positions) as ArrayRef,
3226 None,
3227 )?) as ArrayRef])
3228 }
3229 CompressedPositionStorage::SharedStream(positions) => {
3230 let mut columns = Vec::with_capacity(2);
3231 columns.push(
3232 Arc::new(LargeBinaryArray::from(vec![Some(positions.bytes())])) as ArrayRef,
3233 );
3234
3235 let mut offsets_builder = ListBuilder::new(UInt32Builder::new());
3236 for &offset in positions.block_offsets() {
3237 offsets_builder.values().append_value(offset);
3238 }
3239 offsets_builder.append(true);
3240 columns.push(Arc::new(offsets_builder.finish()) as ArrayRef);
3241 Ok(columns)
3242 }
3243 }
3244 }
3245
3246 fn build_batch(
3247 self,
3248 compressed: LargeBinaryArray,
3249 max_score: f32,
3250 schema: SchemaRef,
3251 positions: Option<CompressedPositionStorage>,
3252 ) -> Result<RecordBatch> {
3253 let length = self.len();
3254 let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
3255 let mut columns = vec![
3256 Arc::new(ListArray::try_new(
3257 Arc::new(Field::new("item", datatypes::DataType::LargeBinary, true)),
3258 offsets,
3259 Arc::new(compressed),
3260 None,
3261 )?) as ArrayRef,
3262 Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
3263 Arc::new(UInt32Array::from_iter_values(std::iter::once(
3264 length as u32,
3265 ))) as ArrayRef,
3266 ];
3267 columns.extend(Self::build_position_columns(positions)?);
3268
3269 let batch = RecordBatch::try_new(schema, columns)?;
3270 Ok(batch)
3271 }
3272
3273 fn build_legacy_positions(&self) -> Result<ListArray> {
3274 let mut positions_builder = ListBuilder::new(LargeBinaryBuilder::new());
3275 self.for_each_entry(|_doc_id, frequency, positions| {
3276 let positions = positions.ok_or_else(|| {
3277 Error::index(format!(
3278 "legacy position writer missing positions for frequency {}",
3279 frequency
3280 ))
3281 })?;
3282 let compressed = super::encoding::compress_positions(positions.as_slice())?;
3283 for block_idx in 0..compressed.len() {
3284 positions_builder
3285 .values()
3286 .append_value(compressed.value(block_idx));
3287 }
3288 positions_builder.append(true);
3289 Ok::<(), Error>(())
3290 })?;
3291 Ok(positions_builder.finish())
3292 }
3293
3294 pub(super) fn append_to_batch_with_docs(
3295 self,
3296 docs: &DocSet,
3297 batch_builder: &mut PostingListBatchBuilder,
3298 format_version: InvertedListFormatVersion,
3299 ) -> Result<()> {
3300 let legacy_positions =
3301 if self.with_positions && !format_version.uses_shared_position_stream() {
3302 Some(self.build_legacy_positions()?)
3303 } else {
3304 None
3305 };
3306 let Self {
3307 with_positions,
3308 posting_tail_codec,
3309 encoded_blocks,
3310 encoded_position_blocks,
3311 tail_entries,
3312 tail_positions,
3313 open_doc_id,
3314 open_doc_frequency,
3315 open_doc_last_position,
3316 len,
3317 ..
3318 } = self;
3319 debug_assert!(open_doc_id.is_none());
3320 debug_assert_eq!(open_doc_frequency, 0);
3321 debug_assert!(open_doc_last_position.is_none());
3322 let parts = PostingListParts {
3323 with_positions,
3324 posting_tail_codec,
3325 length: len as usize,
3326 encoded_blocks: encoded_blocks
3327 .map(|encoded_blocks| *encoded_blocks)
3328 .unwrap_or_default(),
3329 encoded_position_blocks: encoded_position_blocks
3330 .map(|encoded_positions| *encoded_positions)
3331 .unwrap_or_default(),
3332 tail_entries: tail_entries.as_slice(),
3333 tail_position_block: with_positions.then(|| tail_positions.finish()),
3334 };
3335 let (compressed, shared_positions, max_score) =
3336 Self::build_compressed_with_scores_from_parts(parts, docs)?;
3337 let positions = match legacy_positions {
3338 Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3339 None => shared_positions.map(CompressedPositionStorage::SharedStream),
3340 };
3341 batch_builder.append(compressed, max_score, len, positions.as_ref())
3342 }
3343
3344 fn extend_tail_components(
3345 tail_entries: &[RawDocInfo],
3346 doc_ids: &mut Vec<u32>,
3347 frequencies: &mut Vec<u32>,
3348 ) {
3349 doc_ids.clear();
3350 frequencies.clear();
3351 doc_ids.extend(tail_entries.iter().map(|entry| entry.doc_id));
3352 frequencies.extend(tail_entries.iter().map(|entry| entry.frequency));
3353 }
3354
/// Finalizes the encoded posting blocks, computing every block's max score
/// from `docs`.
///
/// Each already-encoded block is decoded to recover its doc ids and
/// frequencies, scored with `compute_block_score`, and the score is written
/// back through `set_block_score`. Any `tail_entries` are then encoded as one
/// remainder block with `posting_tail_codec` and scored the same way.
///
/// Returns the encoded blocks as an array, the shared position stream (only
/// when `with_positions` is set), and the largest block score seen. With no
/// blocks and no tail entries the returned score stays at `f32::MIN`.
///
/// # Errors
///
/// Propagates failures from `append_remainder_block_with_codec`.
///
/// # Panics
///
/// Panics if `with_positions` is set, `tail_entries` is non-empty and
/// `tail_position_block` is `None`.
fn build_compressed_with_scores_from_parts(
    parts: PostingListParts<'_>,
    docs: &DocSet,
) -> Result<(LargeBinaryArray, Option<SharedPositionStream>, f32)> {
    let PostingListParts {
        with_positions,
        posting_tail_codec,
        length,
        mut encoded_blocks,
        mut encoded_position_blocks,
        tail_entries,
        tail_position_block,
    } = parts;
    let avgdl = docs.average_length();
    // Term-level factor shared by every block of this posting list.
    let idf_scale = idf(length, docs.len()) * (K1 + 1.0);
    let mut max_score = f32::MIN;
    // Scratch buffers reused for every block.
    let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
    let mut frequencies = Vec::with_capacity(BLOCK_SIZE);

    for index in 0..encoded_blocks.len() {
        let block = encoded_blocks.block(index);
        doc_ids.clear();
        frequencies.clear();
        super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies);
        let block_score = compute_block_score(
            docs,
            avgdl,
            idf_scale,
            doc_ids.iter().copied(),
            frequencies.iter().copied(),
        );
        max_score = max_score.max(block_score);
        encoded_blocks.set_block_score(index, block_score);
    }

    if !tail_entries.is_empty() {
        Self::extend_tail_components(tail_entries, &mut doc_ids, &mut frequencies);
        let block_score = compute_block_score(
            docs,
            avgdl,
            idf_scale,
            doc_ids.iter().copied(),
            frequencies.iter().copied(),
        );
        max_score = max_score.max(block_score);
        encoded_blocks.append_remainder_block_with_codec(
            doc_ids.as_slice(),
            frequencies.as_slice(),
            posting_tail_codec,
        )?;
        // The remainder block was just appended, so it is the last block.
        encoded_blocks.set_block_score(encoded_blocks.len() - 1, block_score);
        if with_positions {
            encoded_position_blocks.push_encoded_block(
                tail_position_block
                    .as_deref()
                    .expect("tail position block must exist for postings with positions"),
            );
        }
    }

    Ok((
        encoded_blocks.into_array(),
        with_positions.then(|| encoded_position_blocks.into_stream()),
        max_score,
    ))
}
3421
/// Finalizes the encoded posting blocks using caller-supplied block max scores.
///
/// `block_max_scores` is consumed in block order: one score per
/// already-encoded block, followed by one more score for the remainder block
/// when `tail_entries` is non-empty. Extra scores are ignored.
///
/// Returns the encoded blocks as an array, the shared position stream (only
/// when `with_positions` is set), and the largest score consumed. With no
/// blocks and no tail entries the returned score stays at `f32::MIN`.
///
/// # Errors
///
/// Returns an index error when `block_max_scores` runs out early, and
/// propagates failures from `append_remainder_block_with_codec`.
///
/// # Panics
///
/// Panics if `with_positions` is set, `tail_entries` is non-empty and
/// `tail_position_block` is `None`.
fn build_compressed_with_block_scores_from_parts(
    with_positions: bool,
    posting_tail_codec: PostingTailCodec,
    mut encoded_blocks: EncodedBlocks,
    mut encoded_position_blocks: EncodedPositionBlocks,
    tail_entries: &[RawDocInfo],
    tail_position_block: Option<Vec<u8>>,
    mut block_max_scores: impl Iterator<Item = f32>,
) -> Result<(LargeBinaryArray, Option<SharedPositionStream>, f32)> {
    let mut max_score = f32::MIN;
    // Scratch buffers, only filled for the remainder block.
    let mut doc_ids = Vec::with_capacity(BLOCK_SIZE);
    let mut frequencies = Vec::with_capacity(BLOCK_SIZE);

    for index in 0..encoded_blocks.len() {
        let block_score = block_max_scores
            .next()
            .ok_or_else(|| Error::index("missing block max score".to_owned()))?;
        max_score = max_score.max(block_score);
        encoded_blocks.set_block_score(index, block_score);
    }

    if !tail_entries.is_empty() {
        let block_score = block_max_scores
            .next()
            .ok_or_else(|| Error::index("missing tail block max score".to_owned()))?;
        max_score = max_score.max(block_score);
        Self::extend_tail_components(tail_entries, &mut doc_ids, &mut frequencies);
        encoded_blocks.append_remainder_block_with_codec(
            doc_ids.as_slice(),
            frequencies.as_slice(),
            posting_tail_codec,
        )?;
        // The remainder block was just appended, so it is the last block.
        encoded_blocks.set_block_score(encoded_blocks.len() - 1, block_score);
        if with_positions {
            encoded_position_blocks.push_encoded_block(
                tail_position_block
                    .as_deref()
                    .expect("tail position block must exist for postings with positions"),
            );
        }
    }

    Ok((
        encoded_blocks.into_array(),
        with_positions.then(|| encoded_position_blocks.into_stream()),
        max_score,
    ))
}
3470
3471 pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
3472 let format_version = if self.posting_tail_codec == PostingTailCodec::Fixed32 {
3473 InvertedListFormatVersion::V1
3474 } else {
3475 InvertedListFormatVersion::V2
3476 };
3477 let schema = inverted_list_schema_for_version(self.has_positions(), format_version);
3478 let legacy_positions =
3479 if self.with_positions && !format_version.uses_shared_position_stream() {
3480 Some(self.build_legacy_positions()?)
3481 } else {
3482 None
3483 };
3484 let Self {
3485 with_positions,
3486 posting_tail_codec,
3487 encoded_blocks,
3488 encoded_position_blocks,
3489 tail_entries,
3490 tail_positions,
3491 open_doc_id,
3492 open_doc_frequency,
3493 open_doc_last_position,
3494 len,
3495 ..
3496 } = self;
3497 debug_assert!(open_doc_id.is_none());
3498 debug_assert_eq!(open_doc_frequency, 0);
3499 debug_assert!(open_doc_last_position.is_none());
3500 let (compressed, shared_positions, max_score) =
3501 Self::build_compressed_with_block_scores_from_parts(
3502 with_positions,
3503 posting_tail_codec,
3504 encoded_blocks
3505 .map(|encoded_blocks| *encoded_blocks)
3506 .unwrap_or_default(),
3507 encoded_position_blocks
3508 .map(|encoded_positions| *encoded_positions)
3509 .unwrap_or_default(),
3510 tail_entries.as_slice(),
3511 with_positions.then(|| tail_positions.finish()),
3512 block_max_scores.into_iter(),
3513 )?;
3514 let builder = Self {
3515 with_positions,
3516 posting_tail_codec,
3517 encoded_blocks: None,
3518 encoded_position_blocks: None,
3519 tail_entries: Vec::new(),
3520 tail_positions: PositionBlockBuilder::default(),
3521 open_doc_id: None,
3522 open_doc_frequency: 0,
3523 open_doc_last_position: None,
3524 memory_size_bytes: 0,
3525 len,
3526 };
3527 let positions = match legacy_positions {
3528 Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3529 None => shared_positions.map(CompressedPositionStorage::SharedStream),
3530 };
3531 builder.build_batch(compressed, max_score, schema, positions)
3532 }
3533
3534 pub fn to_batch_with_docs(self, docs: &DocSet, schema: SchemaRef) -> Result<RecordBatch> {
3535 let format_version = if schema.column_with_name(POSITION_COL).is_some()
3536 && schema.column_with_name(COMPRESSED_POSITION_COL).is_none()
3537 {
3538 InvertedListFormatVersion::V1
3539 } else {
3540 InvertedListFormatVersion::V2
3541 };
3542 let legacy_positions =
3543 if self.with_positions && !format_version.uses_shared_position_stream() {
3544 Some(self.build_legacy_positions()?)
3545 } else {
3546 None
3547 };
3548 let Self {
3549 with_positions,
3550 posting_tail_codec,
3551 encoded_blocks,
3552 encoded_position_blocks,
3553 tail_entries,
3554 tail_positions,
3555 open_doc_id,
3556 open_doc_frequency,
3557 open_doc_last_position,
3558 len,
3559 ..
3560 } = self;
3561 debug_assert!(open_doc_id.is_none());
3562 debug_assert_eq!(open_doc_frequency, 0);
3563 debug_assert!(open_doc_last_position.is_none());
3564 let parts = PostingListParts {
3565 with_positions,
3566 posting_tail_codec,
3567 length: len as usize,
3568 encoded_blocks: encoded_blocks
3569 .map(|encoded_blocks| *encoded_blocks)
3570 .unwrap_or_default(),
3571 encoded_position_blocks: encoded_position_blocks
3572 .map(|encoded_positions| *encoded_positions)
3573 .unwrap_or_default(),
3574 tail_entries: tail_entries.as_slice(),
3575 tail_position_block: with_positions.then(|| tail_positions.finish()),
3576 };
3577 let (compressed, shared_positions, max_score) =
3578 Self::build_compressed_with_scores_from_parts(parts, docs)?;
3579 let builder = Self {
3580 with_positions,
3581 posting_tail_codec,
3582 encoded_blocks: None,
3583 encoded_position_blocks: None,
3584 tail_entries: Vec::new(),
3585 tail_positions: PositionBlockBuilder::default(),
3586 open_doc_id: None,
3587 open_doc_frequency: 0,
3588 open_doc_last_position: None,
3589 memory_size_bytes: 0,
3590 len,
3591 };
3592 let positions = match legacy_positions {
3593 Some(positions) => Some(CompressedPositionStorage::LegacyPerDoc(positions)),
3594 None => shared_positions.map(CompressedPositionStorage::SharedStream),
3595 };
3596 builder.build_batch(compressed, max_score, schema, positions)
3597 }
3598
3599 pub fn remap(&mut self, removed: &[u32]) {
3600 let mut cursor = 0;
3601 let mut new_builder =
3602 Self::new_with_posting_tail_codec(self.has_positions(), self.posting_tail_codec);
3603 for (doc_id, freq, positions) in self.iter() {
3604 while cursor < removed.len() && removed[cursor] < doc_id {
3605 cursor += 1;
3606 }
3607 if cursor < removed.len() && removed[cursor] == doc_id {
3608 continue;
3609 }
3610 let positions = match positions {
3611 Some(positions) => PositionRecorder::Position(positions.into()),
3612 None => PositionRecorder::Count(freq),
3613 };
3614 new_builder.add(doc_id - cursor as u32, positions);
3615 }
3616
3617 *self = new_builder;
3618 }
3619}
3620
3621fn compute_block_score(
3622 docs: &DocSet,
3623 avgdl: f32,
3624 idf_scale: f32,
3625 doc_ids: impl Iterator<Item = u32>,
3626 frequencies: impl Iterator<Item = u32>,
3627) -> f32 {
3628 let mut block_max_score = f32::MIN;
3629 for (doc_id, freq) in doc_ids.zip(frequencies) {
3630 let doc_norm = K1 * (1.0 - B + B * docs.num_tokens(doc_id) as f32 / avgdl);
3631 let freq = freq as f32;
3632 let score = freq / (freq + doc_norm);
3633 block_max_score = block_max_score.max(score);
3634 }
3635 block_max_score * idf_scale
3636}
3637
3638#[derive(Debug, Clone, DeepSizeOf, Copy)]
3639pub enum DocInfo {
3640 Located(LocatedDocInfo),
3641 Raw(RawDocInfo),
3642}
3643
3644impl DocInfo {
3645 pub fn doc_id(&self) -> u64 {
3646 match self {
3647 Self::Raw(info) => info.doc_id as u64,
3648 Self::Located(info) => info.row_id,
3649 }
3650 }
3651
3652 pub fn frequency(&self) -> u32 {
3653 match self {
3654 Self::Raw(info) => info.frequency,
3655 Self::Located(info) => info.frequency as u32,
3656 }
3657 }
3658}
3659
3660impl Eq for DocInfo {}
3661
3662impl PartialEq for DocInfo {
3663 fn eq(&self, other: &Self) -> bool {
3664 self.doc_id() == other.doc_id()
3665 }
3666}
3667
3668impl PartialOrd for DocInfo {
3669 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3670 Some(self.cmp(other))
3671 }
3672}
3673
3674impl Ord for DocInfo {
3675 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3676 self.doc_id().cmp(&other.doc_id())
3677 }
3678}
3679
3680#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
3681pub struct LocatedDocInfo {
3682 pub row_id: u64,
3683 pub frequency: f32,
3684}
3685
3686impl LocatedDocInfo {
3687 pub fn new(row_id: u64, frequency: f32) -> Self {
3688 Self { row_id, frequency }
3689 }
3690}
3691
3692impl Eq for LocatedDocInfo {}
3693
3694impl PartialEq for LocatedDocInfo {
3695 fn eq(&self, other: &Self) -> bool {
3696 self.row_id == other.row_id
3697 }
3698}
3699
3700impl PartialOrd for LocatedDocInfo {
3701 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3702 Some(self.cmp(other))
3703 }
3704}
3705
3706impl Ord for LocatedDocInfo {
3707 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3708 self.row_id.cmp(&other.row_id)
3709 }
3710}
3711
3712#[derive(Debug, Clone, Default, DeepSizeOf, Copy)]
3713pub struct RawDocInfo {
3714 pub doc_id: u32,
3715 pub frequency: u32,
3716}
3717
3718impl RawDocInfo {
3719 pub fn new(doc_id: u32, frequency: u32) -> Self {
3720 Self { doc_id, frequency }
3721 }
3722}
3723
3724impl Eq for RawDocInfo {}
3725
3726impl PartialEq for RawDocInfo {
3727 fn eq(&self, other: &Self) -> bool {
3728 self.doc_id == other.doc_id
3729 }
3730}
3731
3732impl PartialOrd for RawDocInfo {
3733 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
3734 Some(self.cmp(other))
3735 }
3736}
3737
3738impl Ord for RawDocInfo {
3739 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
3740 self.doc_id.cmp(&other.doc_id)
3741 }
3742}
3743
/// The set of documents of an index partition: row ids and token counts
/// stored as parallel vectors indexed by doc id.
#[derive(Debug, Clone, Default, DeepSizeOf)]
pub struct DocSet {
    // Row id of each document, indexed by doc id.
    row_ids: Vec<u64>,
    // Number of tokens of each document, indexed by doc id.
    num_tokens: Vec<u32>,
    // `(row_id, doc_id)` pairs ordered by row id, used to look up a doc id
    // from a row id. Left empty by the legacy load path, in which case
    // `row_ids` itself is sorted and searched directly.
    inv: Vec<(u64, u32)>,

    // Sum of `num_tokens` over all documents.
    total_tokens: u64,
}
3755
3756impl DocSet {
3757 #[inline]
3758 pub fn len(&self) -> usize {
3759 self.row_ids.len()
3760 }
3761
3762 pub fn is_empty(&self) -> bool {
3763 self.len() == 0
3764 }
3765
3766 pub fn iter(&self) -> impl Iterator<Item = (&u64, &u32)> {
3767 self.row_ids.iter().zip(self.num_tokens.iter())
3768 }
3769
3770 pub fn row_id(&self, doc_id: u32) -> u64 {
3771 self.row_ids[doc_id as usize]
3772 }
3773
3774 pub fn doc_id(&self, row_id: u64) -> Option<u64> {
3775 if self.inv.is_empty() {
3776 match self.row_ids.binary_search(&row_id) {
3778 Ok(_) => Some(row_id),
3779 Err(_) => None,
3780 }
3781 } else {
3782 match self.inv.binary_search_by_key(&row_id, |x| x.0) {
3783 Ok(idx) => Some(self.inv[idx].1 as u64),
3784 Err(_) => None,
3785 }
3786 }
3787 }
3788 pub fn total_tokens_num(&self) -> u64 {
3789 self.total_tokens
3790 }
3791
3792 #[inline]
3793 pub fn average_length(&self) -> f32 {
3794 self.total_tokens as f32 / self.len() as f32
3795 }
3796
/// Computes the max score of every block of `BLOCK_SIZE` postings.
///
/// `doc_ids` and `freqs` are walked in lockstep. For each posting the term
/// `freq / (freq + doc_norm)` is evaluated; the largest term of a block is
/// multiplied by the idf scale and pushed. The result has one score per full
/// block, plus one for a trailing partial block.
///
/// The posting-list length is taken from the lower bound of
/// `doc_ids.size_hint()`, and is used both for the idf and to decide whether
/// a trailing partial block exists. Callers must therefore pass an iterator
/// whose lower size hint equals its real length (e.g. a slice iterator).
pub fn calculate_block_max_scores<'a>(
    &self,
    doc_ids: impl Iterator<Item = &'a u32>,
    freqs: impl Iterator<Item = &'a u32>,
) -> Vec<f32> {
    let avgdl = self.average_length();
    // Lower bound of the size hint, treated as the exact posting count.
    let length = doc_ids.size_hint().0;
    let num_blocks = length.div_ceil(BLOCK_SIZE);
    let mut block_max_scores = Vec::with_capacity(num_blocks);
    let idf_scale = idf(length, self.len()) * (K1 + 1.0);
    let mut max_score = f32::MIN;
    for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
        let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
        let freq = *freq as f32;
        let score = freq / (freq + doc_norm);
        if score > max_score {
            max_score = score;
        }
        // A full block is complete: emit its score and start the next block.
        if (i + 1) % BLOCK_SIZE == 0 {
            max_score *= idf_scale;
            block_max_scores.push(max_score);
            max_score = f32::MIN;
        }
    }
    // Emit the trailing partial block, if any.
    if !length.is_multiple_of(BLOCK_SIZE) {
        max_score *= idf_scale;
        block_max_scores.push(max_score);
    }
    block_max_scores
}
3827
3828 pub fn to_batch(&self) -> Result<RecordBatch> {
3829 let row_id_col = UInt64Array::from_iter_values(self.row_ids.iter().cloned());
3830 let num_tokens_col = UInt32Array::from_iter_values(self.num_tokens.iter().cloned());
3831
3832 let schema = arrow_schema::Schema::new(vec![
3833 arrow_schema::Field::new(ROW_ID, DataType::UInt64, false),
3834 arrow_schema::Field::new(NUM_TOKEN_COL, DataType::UInt32, false),
3835 ]);
3836
3837 let batch = RecordBatch::try_new(
3838 Arc::new(schema),
3839 vec![
3840 Arc::new(row_id_col) as ArrayRef,
3841 Arc::new(num_tokens_col) as ArrayRef,
3842 ],
3843 )?;
3844 Ok(batch)
3845 }
3846
3847 pub async fn load(
3848 reader: Arc<dyn IndexReader>,
3849 is_legacy: bool,
3850 frag_reuse_index: Option<Arc<FragReuseIndex>>,
3851 ) -> Result<Self> {
3852 let batch = reader.read_range(0..reader.num_rows(), None).await?;
3853 let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
3854 let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
3855
3856 if is_legacy {
3858 let (row_ids, num_tokens): (Vec<_>, Vec<_>) = row_id_col
3859 .values()
3860 .iter()
3861 .filter_map(|id| {
3862 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
3863 frag_reuse_index_ref.remap_row_id(*id)
3864 } else {
3865 Some(*id)
3866 }
3867 })
3868 .zip(num_tokens_col.values().iter())
3869 .sorted_unstable_by_key(|x| x.0)
3870 .unzip();
3871
3872 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3873 return Ok(Self {
3874 row_ids,
3875 num_tokens,
3876 inv: Vec::new(),
3877 total_tokens,
3878 });
3879 }
3880
3881 if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
3884 let mut row_ids = Vec::with_capacity(row_id_col.len());
3885 let mut num_tokens = Vec::with_capacity(num_tokens_col.len());
3886 for (row_id, num_token) in row_id_col.values().iter().zip(num_tokens_col.values()) {
3887 if let Some(new_row_id) = frag_reuse_index_ref.remap_row_id(*row_id) {
3888 row_ids.push(new_row_id);
3889 num_tokens.push(*num_token);
3890 }
3891 }
3892
3893 let mut inv: Vec<(u64, u32)> = row_ids
3894 .iter()
3895 .enumerate()
3896 .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
3897 .collect();
3898 inv.sort_unstable_by_key(|entry| entry.0);
3899
3900 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3901 return Ok(Self {
3902 row_ids,
3903 num_tokens,
3904 inv,
3905 total_tokens,
3906 });
3907 }
3908
3909 let row_ids = row_id_col.values().to_vec();
3910 let num_tokens = num_tokens_col.values().to_vec();
3911 let mut inv: Vec<(u64, u32)> = row_ids
3912 .iter()
3913 .enumerate()
3914 .map(|(doc_id, row_id)| (*row_id, doc_id as u32))
3915 .collect();
3916 if !row_ids.is_sorted() {
3917 inv.sort_unstable_by_key(|entry| entry.0);
3918 }
3919 let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
3920 Ok(Self {
3921 row_ids,
3922 num_tokens,
3923 inv,
3924 total_tokens,
3925 })
3926 }
3927
3928 pub fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Vec<u32> {
3931 let mut removed = Vec::new();
3932 let len = self.len();
3933 let row_ids = std::mem::replace(&mut self.row_ids, Vec::with_capacity(len));
3934 let num_tokens = std::mem::replace(&mut self.num_tokens, Vec::with_capacity(len));
3935 self.total_tokens = 0;
3936 for (doc_id, (row_id, num_token)) in std::iter::zip(row_ids, num_tokens).enumerate() {
3937 match mapping.get(&row_id) {
3938 Some(Some(new_row_id)) => {
3939 self.row_ids.push(*new_row_id);
3940 self.num_tokens.push(num_token);
3941 self.total_tokens += num_token as u64;
3942 }
3943 Some(None) => {
3944 removed.push(doc_id as u32);
3945 }
3946 None => {
3947 self.row_ids.push(row_id);
3948 self.num_tokens.push(num_token);
3949 self.total_tokens += num_token as u64;
3950 }
3951 }
3952 }
3953 removed
3954 }
3955
3956 #[inline]
3957 pub fn num_tokens(&self, doc_id: u32) -> u32 {
3958 self.num_tokens[doc_id as usize]
3959 }
3960
3961 #[inline]
3964 pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
3965 self.row_ids
3966 .binary_search(&row_id)
3967 .map(|idx| self.num_tokens[idx])
3968 .unwrap_or(0)
3969 }
3970
3971 pub fn append(&mut self, row_id: u64, num_tokens: u32) -> u32 {
3974 self.row_ids.push(row_id);
3975 self.num_tokens.push(num_tokens);
3976 self.total_tokens += num_tokens as u64;
3977 self.row_ids.len() as u32 - 1
3978 }
3979
3980 pub(crate) fn memory_size(&self) -> usize {
3981 self.row_ids.capacity() * std::mem::size_of::<u64>()
3982 + self.num_tokens.capacity() * std::mem::size_of::<u32>()
3983 + self.inv.capacity() * std::mem::size_of::<(u64, u32)>()
3984 }
3985}
3986
3987pub fn flat_full_text_search(
3988 batches: &[&RecordBatch],
3989 doc_col: &str,
3990 query: &str,
3991 tokenizer: Option<Box<dyn LanceTokenizer>>,
3992) -> Result<Vec<u64>> {
3993 if batches.is_empty() {
3994 return Ok(vec![]);
3995 }
3996
3997 if is_phrase_query(query) {
3998 return Err(Error::invalid_input(
3999 "phrase query is not supported for flat full text search, try using FTS index",
4000 ));
4001 }
4002
4003 match batches[0][doc_col].data_type() {
4004 DataType::Utf8 => do_flat_full_text_search::<i32>(batches, doc_col, query, tokenizer),
4005 DataType::LargeUtf8 => do_flat_full_text_search::<i64>(batches, doc_col, query, tokenizer),
4006 data_type => Err(Error::invalid_input(format!(
4007 "unsupported data type {} for inverted index",
4008 data_type
4009 ))),
4010 }
4011}
4012
4013fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
4014 batches: &[&RecordBatch],
4015 doc_col: &str,
4016 query: &str,
4017 tokenizer: Option<Box<dyn LanceTokenizer>>,
4018) -> Result<Vec<u64>> {
4019 let mut results = Vec::new();
4020 let mut tokenizer =
4021 tokenizer.unwrap_or_else(|| InvertedIndexParams::default().build().unwrap());
4022 let query_tokens = collect_query_tokens(query, &mut tokenizer);
4023
4024 for batch in batches {
4025 let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
4026 let doc_array = batch[doc_col].as_string::<Offset>();
4027 for i in 0..row_id_array.len() {
4028 let doc = doc_array.value(i);
4029 if has_query_token(doc, &mut tokenizer, &query_tokens) {
4030 results.push(row_id_array.value(i));
4031 assert!(doc.contains(query));
4034 }
4035 }
4036 }
4037
4038 Ok(results)
4039}
4040
// Column positions in the counted batch produced by `tokenize_and_count`:
// the row id, the total token count of the document, and the per-query-token
// occurrence counts.
const FLAT_ROW_ID_COL_IDX: usize = 0;
const FLAT_ALL_TOKENS_COL_IDX: usize = 1;
const FLAT_QUERY_TOKEN_COUNTS_COL_IDX: usize = 2;
4044
/// Number of accumulated output bytes (1 GiB) above which
/// `tokenize_and_count` logs a one-time warning.
const BYTES_ACCUMULATED_WARNING_THRESHOLD: u64 = 1024 * 1024 * 1024;

/// Tokenizes every document of `input` and counts its tokens.
///
/// The returned batch is the concatenation of one output batch per input
/// batch and has three columns:
/// * the row id column copied from the input,
/// * `all_tokens` (`UInt64`): the total number of tokens of the document,
/// * `query_token_counts` (`FixedSizeList<UInt64>` of `query_tokens.len()`
///   items): how often each query token occurs in the document.
///
/// Null documents are recorded with zero tokens and all-zero counts. Each
/// input batch is processed through `spawn_cpu`, with up to
/// `get_num_compute_intensive_cpus()` batches buffered at a time. When
/// `elapsed_compute` is given, the time spent on each batch is added to it.
///
/// # Errors
///
/// Propagates errors from the input stream, from building an output batch,
/// and from concatenating the output batches.
async fn tokenize_and_count(
    input: impl Stream<Item = DataFusionResult<RecordBatch>> + Send,
    tokenizer: Box<dyn LanceTokenizer>,
    query_tokens: Arc<Tokens>,
    doc_col_idx: usize,
    elapsed_compute: Option<Time>,
) -> DataFusionResult<RecordBatch> {
    let output_schema = Arc::new(Schema::new(vec![
        ROW_ID_FIELD.clone(),
        Field::new("all_tokens", DataType::UInt64, false),
        Field::new(
            "query_token_counts",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::UInt64, true)),
                query_tokens.len() as i32,
            ),
            false,
        ),
    ]));
    // `output_schema` is moved into the closure below; keep a handle for the
    // final concatenation.
    let output_schema_clone = output_schema.clone();
    // Shared across all batch tasks to track the total output size.
    let bytes_accumulated = Arc::new(AtomicU64::new(0));
    let bytes_warning_emitted = Arc::new(AtomicBool::new(false));

    let batches = input
        .map(move |batch| {
            // Every task gets its own tokenizer and its own handles.
            let mut tokenizer = tokenizer.box_clone();
            let output_schema = output_schema.clone();
            let query_tokens = query_tokens.clone();
            let bytes_accumulated = bytes_accumulated.clone();
            let bytes_warning_emitted = bytes_warning_emitted.clone();
            let elapsed_compute = elapsed_compute.clone();
            spawn_cpu(move || {
                let start = std::time::Instant::now();
                let batch = batch?;
                let mut all_token_counts = UInt64Builder::with_capacity(batch.num_rows());
                let mut query_token_counts = FixedSizeListBuilder::with_capacity(
                    UInt64Builder::with_capacity(batch.num_rows() * query_tokens.len()),
                    query_tokens.len() as i32,
                    batch.num_rows(),
                );
                // Per-document scratch buffer, one slot per query token.
                let mut temp_query_token_counts = Vec::with_capacity(query_tokens.len());
                let doc_iter = iter_str_array(batch.column(doc_col_idx));
                for doc in doc_iter {
                    // A null document contributes no tokens at all.
                    let Some(doc) = doc else {
                        all_token_counts.append_value(0);
                        query_token_counts
                            .values()
                            .append_value_n(0, query_tokens.len());
                        query_token_counts.append(true);
                        continue;
                    };

                    temp_query_token_counts.clear();
                    temp_query_token_counts.extend(std::iter::repeat_n(0, query_tokens.len()));

                    let mut stream = tokenizer.token_stream_for_doc(doc);
                    let mut all_tokens = 0;
                    while let Some(token) = stream.next() {
                        all_tokens += 1;
                        if let Some(token_index) = query_tokens.token_index(&token.text) {
                            temp_query_token_counts[token_index] += 1;
                        }
                    }
                    all_token_counts.append_value(all_tokens);
                    for count in temp_query_token_counts.iter().copied() {
                        query_token_counts.values().append_value(count);
                    }
                    query_token_counts.append(true);
                }
                let row_ids = batch[ROW_ID].clone();
                let all_token_counts = all_token_counts.finish();
                let query_token_counts = query_token_counts.finish();
                let result_batch = RecordBatch::try_new(

                    output_schema,
                    vec![
                        row_ids,
                        Arc::new(all_token_counts) as ArrayRef,
                        Arc::new(query_token_counts) as ArrayRef,
                    ],
                )?;
                // `fetch_add` returns the value before the addition, so the
                // check below sees the bytes accumulated by earlier batches only.
                let bytes_accumulated = bytes_accumulated.fetch_add(result_batch.get_array_memory_size() as u64, Ordering::Relaxed);
                // `swap` makes sure the warning is emitted at most once.
                if bytes_accumulated > BYTES_ACCUMULATED_WARNING_THRESHOLD && !bytes_warning_emitted.swap(true, Ordering::Relaxed) {
                    tracing::warn!("Flat full text search is accumulating a large number of bytes. Consider using an FTS index instead.");
                }

                if let Some(t) = &elapsed_compute {
                    t.add_duration(start.elapsed());
                }
                DataFusionResult::Ok(result_batch)
            })
        })
        .buffered(get_num_compute_intensive_cpus())
        .try_collect::<Vec<_>>()
        .await?;

    Ok(arrow::compute::concat_batches(
        &output_schema_clone,
        &batches,
    )?)
}
4165
4166fn initialize_scorer(
4171 base_scorer: Option<&MemBM25Scorer>,
4172 query_tokens: &Tokens,
4173 counted_input: &RecordBatch,
4174) -> MemBM25Scorer {
4175 let mut total_tokens = 0;
4176 let mut num_docs = 0;
4177 let mut all_token_counts = vec![0; query_tokens.len()];
4178
4179 if let Some(base_scorer) = base_scorer {
4180 total_tokens += base_scorer.total_tokens;
4181 num_docs += base_scorer.num_docs;
4182 for (token_index, token) in query_tokens.into_iter().enumerate() {
4183 all_token_counts[token_index] = base_scorer.num_docs_containing_token(token) as u64;
4184 }
4185 }
4186
4187 num_docs += counted_input.num_rows();
4188 total_tokens += arrow::compute::sum(
4189 counted_input
4190 .column(FLAT_ALL_TOKENS_COL_IDX)
4191 .as_primitive::<UInt64Type>(),
4192 )
4193 .unwrap_or_default();
4194
4195 let mut input_token_counters = counted_input
4196 .column(FLAT_QUERY_TOKEN_COUNTS_COL_IDX)
4197 .as_fixed_size_list()
4198 .values()
4199 .as_primitive::<UInt64Type>()
4200 .values()
4201 .iter()
4202 .copied();
4203
4204 for _ in 0..counted_input.num_rows() {
4205 for token_count in all_token_counts.iter_mut() {
4206 *token_count += input_token_counters.next().unwrap_or_default();
4207 }
4208 }
4209
4210 let token_counts_map = all_token_counts
4211 .into_iter()
4212 .enumerate()
4213 .map(|(token_index, count)| {
4214 (
4215 query_tokens.get_token(token_index).to_string(),
4216 count as usize,
4217 )
4218 })
4219 .collect::<HashMap<String, usize>>();
4220 MemBM25Scorer::new(total_tokens, num_docs, token_counts_map)
4221}
4222
4223fn flat_bm25_score(
4224 query_tokens: &Tokens,
4225 counted_input: &RecordBatch,
4226 scorer: &MemBM25Scorer,
4227) -> Result<RecordBatch> {
4228 let mut row_ids_builder = UInt64Builder::with_capacity(counted_input.num_rows());
4229 let mut scores_builder = Float32Builder::with_capacity(counted_input.num_rows());
4230
4231 let mut row_ids_iter = counted_input
4232 .column(FLAT_ROW_ID_COL_IDX)
4233 .as_primitive::<UInt64Type>()
4234 .values()
4235 .iter()
4236 .copied();
4237 let mut all_token_counts_iter = counted_input
4238 .column(FLAT_ALL_TOKENS_COL_IDX)
4239 .as_primitive::<UInt64Type>()
4240 .values()
4241 .iter()
4242 .copied();
4243 let mut query_token_counts_iter = counted_input
4244 .column(FLAT_QUERY_TOKEN_COUNTS_COL_IDX)
4245 .as_fixed_size_list()
4246 .values()
4247 .as_primitive::<UInt64Type>()
4248 .values()
4249 .iter()
4250 .copied();
4251 for _ in 0..counted_input.num_rows() {
4252 let num_tokens_in_doc = all_token_counts_iter.next().expect_ok()?;
4253 let row_id = row_ids_iter.next().expect_ok()?;
4254 if num_tokens_in_doc == 0 {
4255 for _ in query_tokens {
4256 query_token_counts_iter.next().expect_ok()?;
4257 }
4258 continue;
4259 }
4260 let doc_norm = K1 * (1.0 - B + B * num_tokens_in_doc as f32 / scorer.avg_doc_length());
4261 let mut score = 0.0;
4262 for token in query_tokens {
4263 let freq = query_token_counts_iter.next().expect_ok()? as f32;
4264 let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs());
4265 score += idf * (freq * (K1 + 1.0) / (freq + doc_norm));
4266 }
4267 if score > 0.0 {
4268 row_ids_builder.append_value(row_id);
4269 scores_builder.append_value(score);
4270 }
4271 }
4272
4273 let row_ids = row_ids_builder.finish();
4274 let scores = scores_builder.finish();
4275 let batch = RecordBatch::try_new(
4276 FTS_SCHEMA.clone(),
4277 vec![Arc::new(row_ids) as ArrayRef, Arc::new(scores) as ArrayRef],
4278 )?;
4279 Ok(batch)
4280}
4281
/// Runs a flat BM25 search over `input` without recording compute time.
///
/// Thin wrapper that forwards every argument to
/// `flat_bm25_search_stream_with_metrics`, passing `None` as the metric
/// handle.
#[deprecated(
    note = "use `flat_bm25_search_stream_with_metrics` to record CPU compute \
            time on a metric handle; pass `None` for the old behavior"
)]
pub async fn flat_bm25_search_stream(
    input: SendableRecordBatchStream,
    doc_col: String,
    query: String,
    tokenizer: Box<dyn LanceTokenizer>,
    base_scorer: Option<MemBM25Scorer>,
    target_batch_size: usize,
) -> DataFusionResult<SendableRecordBatchStream> {
    flat_bm25_search_stream_with_metrics(
        input,
        doc_col,
        query,
        tokenizer,
        base_scorer,
        target_batch_size,
        // No metric handle: compute time is not recorded.
        None,
    )
    .await
}
4305
4306pub async fn flat_bm25_search_stream_with_metrics(
4312 input: SendableRecordBatchStream,
4313 doc_col: String,
4314 query: String,
4315 tokenizer: Box<dyn LanceTokenizer>,
4316 base_scorer: Option<MemBM25Scorer>,
4317 target_batch_size: usize,
4318 elapsed_compute: Option<Time>,
4319) -> DataFusionResult<SendableRecordBatchStream> {
4320 let mut tokenizer = tokenizer;
4321 let query_tokens = Arc::new(collect_query_tokens(&query, &mut tokenizer));
4322
4323 let input_schema = input.schema();
4324 let doc_col_idx = input_schema.index_of(&doc_col)?;
4325
4326 const ACCUMULATE_BYTES: usize = 256 * 1024;
4328 const SLICE_BYTES: usize = 512 * 1024;
4330
4331 let chunked = lance_arrow::stream::rechunk_stream_by_size(
4334 input,
4335 input_schema,
4336 ACCUMULATE_BYTES,
4337 SLICE_BYTES,
4338 );
4339
4340 let counted_input = tokenize_and_count(
4345 chunked,
4346 tokenizer,
4347 query_tokens.clone(),
4348 doc_col_idx,
4349 elapsed_compute.clone(),
4350 )
4351 .await?;
4352
4353 let scoring_start = std::time::Instant::now();
4356 let scorer = initialize_scorer(base_scorer.as_ref(), query_tokens.as_ref(), &counted_input);
4357 let scores = flat_bm25_score(query_tokens.as_ref(), &counted_input, &scorer)?;
4358 if let Some(t) = &elapsed_compute {
4359 t.add_duration(scoring_start.elapsed());
4360 }
4361
4362 let num_out_batches = scores.num_rows().div_ceil(target_batch_size);
4364 let mut batches = Vec::with_capacity(num_out_batches);
4365 for i in 0..num_out_batches {
4366 let start = i * target_batch_size;
4367 let len = (scores.num_rows() - start).min(target_batch_size);
4368 batches.push(Ok(scores.slice(start, len)));
4369 }
4370 Ok(Box::pin(RecordBatchStreamAdapter::new(
4371 FTS_SCHEMA.clone(),
4372 stream::iter(batches),
4373 )))
4374}
4375
/// Returns `true` when `query` is a phrase query, i.e. it is wrapped in a
/// pair of double quotes (`"like this"`).
///
/// The opening and the closing quote must be two distinct characters, so a
/// lone `"` is not a phrase query, while `""` is an (empty) phrase.
pub fn is_phrase_query(query: &str) -> bool {
    // The length check stops a single quote character from being counted as
    // both the opening and the closing quote.
    query.len() >= 2 && query.starts_with('"') && query.ends_with('"')
}
4379
4380#[cfg(test)]
4381mod tests {
4382 use crate::scalar::inverted::document_tokenizer::DocType;
4383 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
4384 use futures::stream;
4385 use lance_core::cache::LanceCache;
4386 use lance_core::utils::tempfile::TempObjDir;
4387 use lance_io::object_store::ObjectStore;
4388
4389 use crate::metrics::NoOpMetricsCollector;
4390 use crate::prefilter::NoFilter;
4391 use crate::scalar::ScalarIndex;
4392 use crate::scalar::inverted::builder::{
4393 InnerBuilder, InvertedIndexBuilder, PositionRecorder, inverted_list_schema,
4394 };
4395 use crate::scalar::inverted::encoding::{
4396 compress_positions, compress_posting_list_with_tail_codec,
4397 decompress_posting_list_with_tail_codec, encode_position_stream_block_into,
4398 };
4399 use crate::scalar::inverted::query::{FtsSearchParams, Operator};
4400 use crate::scalar::lance_format::LanceIndexStore;
4401 use arrow::array::{AsArray, LargeBinaryBuilder, ListBuilder, UInt32Builder};
4402 use arrow::datatypes::{Float32Type, UInt32Type};
4403 use arrow_array::{ArrayRef, Float32Array, RecordBatch, StringArray, UInt32Array, UInt64Array};
4404 use arrow_schema::{DataType, Field, Schema};
4405 use std::collections::HashMap;
4406 use std::sync::Arc;
4407
4408 use super::*;
4409
4410 async fn write_single_partition_index(
4411 store: Arc<LanceIndexStore>,
4412 params: InvertedIndexParams,
4413 token_set_format: TokenSetFormat,
4414 token: &str,
4415 row_id: u64,
4416 ) -> Result<Arc<InvertedIndex>> {
4417 let mut partition = InnerBuilder::new_with_format_version(
4418 0,
4419 false,
4420 token_set_format,
4421 InvertedListFormatVersion::V1,
4422 );
4423 partition.tokens.add(token.to_owned());
4424 let mut posting_list =
4425 PostingListBuilder::new_with_posting_tail_codec(false, PostingTailCodec::Fixed32);
4426 posting_list.add(0, PositionRecorder::Count(1));
4427 partition.posting_lists.push(posting_list);
4428 partition.docs.append(row_id, 1);
4429 partition.write(store.as_ref()).await?;
4430
4431 let metadata = HashMap::from([
4432 (
4433 "partitions".to_owned(),
4434 serde_json::to_string(&vec![0_u64]).unwrap(),
4435 ),
4436 ("params".to_owned(), serde_json::to_string(¶ms).unwrap()),
4437 (
4438 TOKEN_SET_FORMAT_KEY.to_owned(),
4439 token_set_format.to_string(),
4440 ),
4441 ]);
4442 let mut writer = store
4443 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
4444 .await?;
4445 writer.finish_with_metadata(metadata).await?;
4446
4447 InvertedIndex::load(store, None, &LanceCache::no_cache()).await
4448 }
4449
4450 fn empty_doc_stream() -> SendableRecordBatchStream {
4451 let schema = Arc::new(Schema::new(vec![
4452 Field::new("doc", DataType::Utf8, true),
4453 Field::new(ROW_ID, DataType::UInt64, false),
4454 ]));
4455 Box::pin(RecordBatchStreamAdapter::new(
4456 schema,
4457 stream::iter(Vec::<datafusion::error::Result<RecordBatch>>::new()),
4458 ))
4459 }
4460
4461 #[tokio::test]
4462 async fn test_posting_builder_remap() {
4463 let posting_tail_codec = PostingTailCodec::Fixed32;
4464 let mut builder =
4465 PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
4466 let n = BLOCK_SIZE + 3;
4467 for i in 0..n {
4468 builder.add(i as u32, PositionRecorder::Count(1));
4469 }
4470 let removed = vec![5, 7];
4471 builder.remap(&removed);
4472
4473 let mut expected =
4474 PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
4475 for i in 0..n - removed.len() {
4476 expected.add(i as u32, PositionRecorder::Count(1));
4477 }
4478 let expected_entries = expected.iter().collect::<Vec<_>>();
4479 let actual_entries = builder.iter().collect::<Vec<_>>();
4480 assert_eq!(actual_entries, expected_entries);
4481
4482 let batch = builder.to_batch(vec![1.0, 2.0]).unwrap();
4485 let (doc_ids, freqs) = decompress_posting_list_with_tail_codec(
4486 (n - removed.len()) as u32,
4487 batch[POSTING_COL]
4488 .as_list::<i32>()
4489 .value(0)
4490 .as_binary::<i64>(),
4491 posting_tail_codec,
4492 )
4493 .unwrap();
4494 assert!(
4495 doc_ids
4496 .iter()
4497 .zip(expected_entries.iter().map(|(doc_id, _, _)| doc_id))
4498 .all(|(a, b)| a == b)
4499 );
4500 assert!(
4501 freqs
4502 .iter()
4503 .zip(expected_entries.iter().map(|(_, freq, _)| freq))
4504 .all(|(a, b)| a == b)
4505 );
4506 }
4507
4508 #[test]
4509 fn test_posting_builder_size_tracking_matches_structure() {
4510 fn tracked_memory_size(builder: &PostingListBuilder) -> u64 {
4511 let encoded_blocks_size = builder
4512 .encoded_blocks
4513 .iter()
4514 .map(|encoded_blocks| std::mem::size_of::<EncodedBlocks>() + encoded_blocks.size())
4515 .sum::<usize>();
4516 let encoded_positions_size = builder
4517 .encoded_position_blocks
4518 .as_ref()
4519 .map(|positions| std::mem::size_of::<EncodedPositionBlocks>() + positions.size())
4520 .unwrap_or(0usize);
4521 (encoded_blocks_size
4522 + builder.tail_entries.capacity() * std::mem::size_of::<RawDocInfo>()
4523 + builder.tail_positions.size()
4524 + encoded_positions_size) as u64
4525 }
4526
4527 let mut builder = PostingListBuilder::new(true);
4528 for doc_id in 0..(BLOCK_SIZE + 5) as u32 {
4529 builder.add(
4530 doc_id,
4531 PositionRecorder::Position(smallvec::smallvec![1, 3, 5]),
4532 );
4533 }
4534
4535 assert_eq!(builder.size(), tracked_memory_size(&builder));
4536 }
4537
4538 #[test]
4539 fn test_posting_builder_flush_releases_tail_position_capacity() {
4540 let mut builder = PostingListBuilder::new(true);
4541 let positions = smallvec::SmallVec::<[u32; 2]>::from_vec((0..1024).collect());
4542 for doc_id in 0..BLOCK_SIZE as u32 {
4543 builder.add(doc_id, PositionRecorder::Position(positions.clone()));
4544 }
4545
4546 assert_eq!(builder.tail_positions.size(), 0);
4547 assert_eq!(builder.size(), {
4548 let encoded_blocks_size = builder
4549 .encoded_blocks
4550 .iter()
4551 .map(|encoded_blocks| std::mem::size_of::<EncodedBlocks>() + encoded_blocks.size())
4552 .sum::<usize>();
4553 let encoded_positions_size = builder
4554 .encoded_position_blocks
4555 .as_ref()
4556 .map(|positions| std::mem::size_of::<EncodedPositionBlocks>() + positions.size())
4557 .unwrap_or(0usize);
4558 (encoded_blocks_size
4559 + builder.tail_entries.capacity() * std::mem::size_of::<RawDocInfo>()
4560 + builder.tail_positions.size()
4561 + encoded_positions_size) as u64
4562 });
4563 }
4564
4565 #[test]
4566 fn test_posting_builder_streamed_positions_roundtrip() {
4567 let mut builder = PostingListBuilder::new(true);
4568 assert!(builder.add_occurrence(0, 1).unwrap());
4569 assert!(!builder.add_occurrence(0, 4).unwrap());
4570 assert!(!builder.add_occurrence(0, 9).unwrap());
4571 builder.finish_open_doc(0).unwrap();
4572
4573 assert!(builder.add_occurrence(2, 3).unwrap());
4574 builder.finish_open_doc(2).unwrap();
4575
4576 let entries = builder.iter().collect::<Vec<_>>();
4577 assert_eq!(
4578 entries,
4579 vec![
4580 (0_u32, 3_u32, Some(vec![1_u32, 4_u32, 9_u32])),
4581 (2_u32, 1_u32, Some(vec![3_u32])),
4582 ]
4583 );
4584 }
4585
4586 #[test]
4587 fn test_posting_builder_roundtrip_shared_positions() {
4588 let entries = vec![
4589 (0_u32, vec![1_u32, 5]),
4590 (2, vec![0, 4, 9]),
4591 (4, vec![7]),
4592 (8, vec![3, 10]),
4593 (13, vec![2, 11, 30]),
4594 ];
4595 let mut builder =
4596 PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::VarintDelta);
4597 for (doc_id, positions) in &entries {
4598 builder.add(
4599 *doc_id,
4600 PositionRecorder::Position(positions.clone().into()),
4601 );
4602 }
4603
4604 let batch = builder.to_batch(vec![1.0]).unwrap();
4605 assert!(batch.column_by_name(COMPRESSED_POSITION_COL).is_some());
4606 assert!(batch.column_by_name(POSITION_COL).is_none());
4607 assert_eq!(
4608 batch.schema_ref().metadata().get(POSTING_TAIL_CODEC_KEY),
4609 Some(&PostingTailCodec::VarintDelta.as_str().to_owned())
4610 );
4611 assert_eq!(
4612 batch.schema_ref().metadata().get(POSITIONS_LAYOUT_KEY),
4613 Some(&POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned())
4614 );
4615 assert_eq!(
4616 batch.schema_ref().metadata().get(POSITIONS_CODEC_KEY),
4617 Some(&PositionStreamCodec::PackedDelta.as_str().to_owned())
4618 );
4619
4620 let posting =
4621 PostingList::from_batch(&batch, Some(1.0), Some(entries.len() as u32)).unwrap();
4622 let actual = posting
4623 .iter()
4624 .map(|(doc_id, freq, positions)| {
4625 (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4626 })
4627 .collect::<Vec<_>>();
4628 let expected = entries
4629 .iter()
4630 .map(|(doc_id, positions)| (*doc_id, positions.len() as u32, positions.clone()))
4631 .collect::<Vec<_>>();
4632 assert_eq!(actual, expected);
4633 }
4634
4635 #[test]
4636 fn test_posting_builder_roundtrip_legacy_positions() {
4637 let entries = vec![(0_u32, vec![1_u32, 5]), (2, vec![0, 4, 9]), (4, vec![7])];
4638 let mut builder =
4639 PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::Fixed32);
4640 for (doc_id, positions) in &entries {
4641 builder.add(
4642 *doc_id,
4643 PositionRecorder::Position(positions.clone().into()),
4644 );
4645 }
4646
4647 let batch = builder.to_batch(vec![1.0]).unwrap();
4648 assert!(batch.column_by_name(POSITION_COL).is_some());
4649 assert!(batch.column_by_name(COMPRESSED_POSITION_COL).is_none());
4650 assert_eq!(
4651 batch.schema_ref().metadata().get(POSTING_TAIL_CODEC_KEY),
4652 None
4653 );
4654 assert_eq!(
4655 batch.schema_ref().metadata().get(POSITIONS_LAYOUT_KEY),
4656 None
4657 );
4658 assert_eq!(batch.schema_ref().metadata().get(POSITIONS_CODEC_KEY), None);
4659
4660 let posting =
4661 PostingList::from_batch(&batch, Some(1.0), Some(entries.len() as u32)).unwrap();
4662 let actual = posting
4663 .iter()
4664 .map(|(doc_id, freq, positions)| {
4665 (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4666 })
4667 .collect::<Vec<_>>();
4668 let expected = entries
4669 .iter()
4670 .map(|(doc_id, positions)| (*doc_id, positions.len() as u32, positions.clone()))
4671 .collect::<Vec<_>>();
4672 assert_eq!(actual, expected);
4673 }
4674
4675 #[test]
4676 fn test_resolve_fts_format_version_defaults_to_v1() {
4677 assert_eq!(
4678 resolve_fts_format_version(None).unwrap(),
4679 InvertedListFormatVersion::V1
4680 );
4681 assert_eq!(
4682 resolve_fts_format_version(Some("2")).unwrap(),
4683 InvertedListFormatVersion::V2
4684 );
4685 }
4686
4687 #[test]
4688 fn test_legacy_compressed_positions_still_readable() {
4689 let doc_ids = [1_u32, 3_u32];
4690 let frequencies = [2_u32, 3_u32];
4691 let posting = compress_posting_list_with_tail_codec(
4692 doc_ids.len(),
4693 doc_ids.iter(),
4694 frequencies.iter(),
4695 std::iter::once(1.0_f32),
4696 PostingTailCodec::Fixed32,
4697 )
4698 .unwrap();
4699
4700 let mut posting_builder = ListBuilder::new(LargeBinaryBuilder::new());
4701 for idx in 0..posting.len() {
4702 posting_builder.values().append_value(posting.value(idx));
4703 }
4704 posting_builder.append(true);
4705
4706 let mut positions_builder = ListBuilder::new(ListBuilder::new(LargeBinaryBuilder::new()));
4707 for positions in [vec![1_u32, 5_u32], vec![0_u32, 4_u32, 9_u32]] {
4708 let compressed = compress_positions(&positions).unwrap();
4709 let doc_builder = positions_builder.values();
4710 for idx in 0..compressed.len() {
4711 doc_builder.values().append_value(compressed.value(idx));
4712 }
4713 doc_builder.append(true);
4714 }
4715 positions_builder.append(true);
4716
4717 let schema = Arc::new(Schema::new(vec![
4718 Field::new(
4719 POSTING_COL,
4720 DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4721 false,
4722 ),
4723 Field::new(MAX_SCORE_COL, DataType::Float32, false),
4724 Field::new(LENGTH_COL, DataType::UInt32, false),
4725 Field::new(
4726 POSITION_COL,
4727 DataType::List(Arc::new(Field::new(
4728 "item",
4729 DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4730 true,
4731 ))),
4732 false,
4733 ),
4734 ]));
4735 let batch = RecordBatch::try_new(
4736 schema,
4737 vec![
4738 Arc::new(posting_builder.finish()) as ArrayRef,
4739 Arc::new(Float32Array::from(vec![1.0])) as ArrayRef,
4740 Arc::new(UInt32Array::from(vec![doc_ids.len() as u32])) as ArrayRef,
4741 Arc::new(positions_builder.finish()) as ArrayRef,
4742 ],
4743 )
4744 .unwrap();
4745
4746 let posting =
4747 PostingList::from_batch(&batch, Some(1.0), Some(doc_ids.len() as u32)).unwrap();
4748 let actual = posting
4749 .iter()
4750 .map(|(doc_id, freq, positions)| {
4751 (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4752 })
4753 .collect::<Vec<_>>();
4754 assert_eq!(actual, vec![(1, 2, vec![1, 5]), (3, 3, vec![0, 4, 9]),]);
4755 }
4756
4757 #[test]
4758 fn test_shared_stream_v2_without_codec_still_readable() {
4759 let doc_ids = [1_u32, 3_u32];
4760 let frequencies = [2_u32, 3_u32];
4761 let posting = compress_posting_list_with_tail_codec(
4762 doc_ids.len(),
4763 doc_ids.iter(),
4764 frequencies.iter(),
4765 std::iter::once(1.0_f32),
4766 PostingTailCodec::Fixed32,
4767 )
4768 .unwrap();
4769
4770 let mut posting_builder = ListBuilder::new(LargeBinaryBuilder::new());
4771 for idx in 0..posting.len() {
4772 posting_builder.values().append_value(posting.value(idx));
4773 }
4774 posting_builder.append(true);
4775
4776 let positions = vec![1_u32, 5_u32, 0_u32, 4_u32, 9_u32];
4777 let mut encoded_positions = Vec::new();
4778 encode_position_stream_block_into(
4779 &positions,
4780 &frequencies,
4781 PositionStreamCodec::VarintDocDelta,
4782 &mut encoded_positions,
4783 )
4784 .unwrap();
4785
4786 let mut position_offsets = ListBuilder::new(UInt32Builder::new());
4787 position_offsets.values().append_value(0);
4788 position_offsets.append(true);
4789
4790 let schema = Arc::new(Schema::new_with_metadata(
4791 vec![
4792 Field::new(
4793 POSTING_COL,
4794 DataType::List(Arc::new(Field::new("item", DataType::LargeBinary, true))),
4795 false,
4796 ),
4797 Field::new(MAX_SCORE_COL, DataType::Float32, false),
4798 Field::new(LENGTH_COL, DataType::UInt32, false),
4799 Field::new(COMPRESSED_POSITION_COL, DataType::LargeBinary, false),
4800 Field::new(
4801 POSITION_BLOCK_OFFSET_COL,
4802 DataType::List(Arc::new(Field::new("item", DataType::UInt32, true))),
4803 false,
4804 ),
4805 ],
4806 HashMap::from([(
4807 POSITIONS_LAYOUT_KEY.to_owned(),
4808 POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned(),
4809 )]),
4810 ));
4811 let batch = RecordBatch::try_new(
4812 schema,
4813 vec![
4814 Arc::new(posting_builder.finish()) as ArrayRef,
4815 Arc::new(Float32Array::from(vec![1.0])) as ArrayRef,
4816 Arc::new(UInt32Array::from(vec![doc_ids.len() as u32])) as ArrayRef,
4817 Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
4818 encoded_positions.as_slice(),
4819 )])) as ArrayRef,
4820 Arc::new(position_offsets.finish()) as ArrayRef,
4821 ],
4822 )
4823 .unwrap();
4824
4825 let posting =
4826 PostingList::from_batch(&batch, Some(1.0), Some(doc_ids.len() as u32)).unwrap();
4827 let actual = posting
4828 .iter()
4829 .map(|(doc_id, freq, positions)| {
4830 (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
4831 })
4832 .collect::<Vec<_>>();
4833 assert_eq!(actual, vec![(1, 2, vec![1, 5]), (3, 3, vec![0, 4, 9]),]);
4834 }
4835
4836 #[test]
4837 fn test_shared_position_stream_is_smaller_for_sparse_positions() {
4838 let mut builder =
4839 PostingListBuilder::new_with_posting_tail_codec(true, PostingTailCodec::VarintDelta);
4840 let mut legacy_positions = Vec::with_capacity(BLOCK_SIZE * 4);
4841 for doc_id in 0..(BLOCK_SIZE * 4) as u32 {
4842 let mut positions = vec![doc_id * 3 + 1];
4843 if doc_id % 8 == 0 {
4844 positions.push(doc_id * 3 + 2);
4845 }
4846 builder.add(doc_id, PositionRecorder::Position(positions.clone().into()));
4847 legacy_positions.push(positions);
4848 }
4849
4850 let batch = builder.to_batch(vec![1.0; 4]).unwrap();
4851 let shared_positions_size = batch[COMPRESSED_POSITION_COL].get_buffer_memory_size()
4852 + batch[POSITION_BLOCK_OFFSET_COL].get_buffer_memory_size();
4853
4854 let mut positions_builder = ListBuilder::new(ListBuilder::new(LargeBinaryBuilder::new()));
4855 for positions in legacy_positions {
4856 let compressed = compress_positions(&positions).unwrap();
4857 let doc_builder = positions_builder.values();
4858 for idx in 0..compressed.len() {
4859 doc_builder.values().append_value(compressed.value(idx));
4860 }
4861 doc_builder.append(true);
4862 }
4863 positions_builder.append(true);
4864 let legacy_positions_size = positions_builder.finish().get_buffer_memory_size();
4865
4866 assert!(
4867 shared_positions_size < legacy_positions_size,
4868 "expected shared position stream to be smaller than legacy per-doc storage, shared={shared_positions_size}, legacy={legacy_positions_size}",
4869 );
4870 }
4871
4872 #[test]
4873 fn test_posting_list_batch_matches_docset_scoring() {
4874 let mut docs = DocSet::default();
4875 let num_docs = BLOCK_SIZE + 3;
4876 for doc_id in 0..num_docs as u32 {
4877 docs.append(doc_id as u64, doc_id % 7 + 1);
4878 }
4879
4880 let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
4881 let freqs = doc_ids
4882 .iter()
4883 .map(|doc_id| doc_id % 5 + 1)
4884 .collect::<Vec<_>>();
4885
4886 let mut builder_scores = PostingListBuilder::new(false);
4887 let mut builder_docs = PostingListBuilder::new(false);
4888 for (&doc_id, &freq) in doc_ids.iter().zip(freqs.iter()) {
4889 builder_scores.add(doc_id, PositionRecorder::Count(freq));
4890 builder_docs.add(doc_id, PositionRecorder::Count(freq));
4891 }
4892
4893 let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
4894 let batch_scores = builder_scores.to_batch(block_max_scores).unwrap();
4895 let batch_docs = builder_docs
4896 .to_batch_with_docs(&docs, inverted_list_schema(false))
4897 .unwrap();
4898
4899 let scores_posting = batch_scores[POSTING_COL].as_list::<i32>().value(0);
4900 let scores_posting = scores_posting.as_binary::<i64>();
4901 let docs_posting = batch_docs[POSTING_COL].as_list::<i32>().value(0);
4902 let docs_posting = docs_posting.as_binary::<i64>();
4903 assert_eq!(scores_posting, docs_posting);
4904
4905 let score_left = batch_scores[MAX_SCORE_COL]
4906 .as_primitive::<Float32Type>()
4907 .value(0);
4908 let score_right = batch_docs[MAX_SCORE_COL]
4909 .as_primitive::<Float32Type>()
4910 .value(0);
4911 assert!((score_left - score_right).abs() < 1e-6);
4912
4913 let len_left = batch_scores[LENGTH_COL]
4914 .as_primitive::<UInt32Type>()
4915 .value(0);
4916 let len_right = batch_docs[LENGTH_COL].as_primitive::<UInt32Type>().value(0);
4917 assert_eq!(len_left, len_right);
4918 }
4919
4920 #[tokio::test]
4921 async fn test_remap_to_empty_posting_list() {
4922 let tmpdir = TempObjDir::default();
4923 let store = Arc::new(LanceIndexStore::new(
4924 ObjectStore::local().into(),
4925 tmpdir.clone(),
4926 Arc::new(LanceCache::no_cache()),
4927 ));
4928
4929 let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
4930
4931 builder.tokens.add("lance".to_owned());
4936 builder.tokens.add("lake".to_owned());
4937 builder.posting_lists.push(PostingListBuilder::new(false));
4938 builder.posting_lists.push(PostingListBuilder::new(false));
4939 builder.posting_lists[0].add(0, PositionRecorder::Count(1));
4940 builder.posting_lists[1].add(1, PositionRecorder::Count(2));
4941 builder.posting_lists[1].add(2, PositionRecorder::Count(3));
4942 builder.docs.append(0, 1);
4943 builder.docs.append(1, 1);
4944 builder.docs.append(2, 1);
4945 builder.write(store.as_ref()).await.unwrap();
4946
4947 let index = InvertedPartition::load(
4948 store.clone(),
4949 0,
4950 None,
4951 &LanceCache::no_cache(),
4952 TokenSetFormat::default(),
4953 )
4954 .await
4955 .unwrap();
4956 let mut builder = index.into_builder().await.unwrap();
4957
4958 let mapping = HashMap::from([(0, None), (2, Some(3))]);
4959 builder.remap(&mapping).await.unwrap();
4960
4961 assert_eq!(builder.tokens.len(), 1);
4963 assert_eq!(builder.tokens.get("lake"), Some(0));
4964 assert_eq!(builder.posting_lists.len(), 1);
4965 assert_eq!(builder.posting_lists[0].len(), 2);
4966 assert_eq!(builder.docs.len(), 2);
4967 assert_eq!(builder.docs.row_id(0), 1);
4968 assert_eq!(builder.docs.row_id(1), 3);
4969
4970 builder.write(store.as_ref()).await.unwrap();
4971
4972 let mapping = HashMap::from([(1, None), (3, None)]);
4974 builder.remap(&mapping).await.unwrap();
4975
4976 assert_eq!(builder.tokens.len(), 0);
4977 assert_eq!(builder.posting_lists.len(), 0);
4978 assert_eq!(builder.docs.len(), 0);
4979
4980 builder.write(store.as_ref()).await.unwrap();
4981 }
4982
4983 #[tokio::test]
4984 async fn test_posting_cache_conflict_across_partitions() {
4985 let tmpdir = TempObjDir::default();
4986 let store = Arc::new(LanceIndexStore::new(
4987 ObjectStore::local().into(),
4988 tmpdir.clone(),
4989 Arc::new(LanceCache::no_cache()),
4990 ));
4991
4992 let mut builder1 = InnerBuilder::new(0, false, TokenSetFormat::default());
4994 builder1.tokens.add("test".to_owned());
4995 builder1.posting_lists.push(PostingListBuilder::new(false));
4996 builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
4997 builder1.docs.append(100, 1); builder1.write(store.as_ref()).await.unwrap();
4999
5000 let mut builder2 = InnerBuilder::new(1, false, TokenSetFormat::default());
5002 builder2.tokens.add("test".to_owned()); builder2.posting_lists.push(PostingListBuilder::new(false));
5004 builder2.posting_lists[0].add(0, PositionRecorder::Count(2));
5005 builder2.posting_lists[0].add(1, PositionRecorder::Count(1));
5006 builder2.posting_lists[0].add(2, PositionRecorder::Count(3));
5007 builder2.posting_lists[0].add(3, PositionRecorder::Count(1));
5008 builder2.docs.append(200, 2); builder2.docs.append(201, 1); builder2.docs.append(202, 3); builder2.docs.append(203, 1); builder2.write(store.as_ref()).await.unwrap();
5013
5014 let metadata = std::collections::HashMap::from_iter(vec![
5016 (
5017 "partitions".to_owned(),
5018 serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
5019 ),
5020 (
5021 "params".to_owned(),
5022 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5023 ),
5024 (
5025 TOKEN_SET_FORMAT_KEY.to_owned(),
5026 TokenSetFormat::default().to_string(),
5027 ),
5028 ]);
5029 let mut writer = store
5030 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5031 .await
5032 .unwrap();
5033 writer.finish_with_metadata(metadata).await.unwrap();
5034
5035 let cache = Arc::new(LanceCache::with_capacity(4096));
5037 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5038 .await
5039 .unwrap();
5040
5041 assert_eq!(index.partitions.len(), 2);
5043 assert_eq!(index.partitions[0].tokens.len(), 1);
5044 assert_eq!(index.partitions[1].tokens.len(), 1);
5045
5046 if index.partitions[0].id() == 0 {
5051 assert_eq!(index.partitions[0].inverted_list.posting_len(0), 1);
5053 assert_eq!(index.partitions[1].inverted_list.posting_len(0), 4);
5054 assert_eq!(index.partitions[0].docs.len(), 1);
5055 assert_eq!(index.partitions[1].docs.len(), 4);
5056 } else {
5057 assert_eq!(index.partitions[0].inverted_list.posting_len(0), 4);
5059 assert_eq!(index.partitions[1].inverted_list.posting_len(0), 1);
5060 assert_eq!(index.partitions[0].docs.len(), 4);
5061 assert_eq!(index.partitions[1].docs.len(), 1);
5062 }
5063
5064 index.prewarm().await.unwrap();
5066
5067 let tokens = Arc::new(Tokens::new(vec!["test".to_string()], DocType::Text));
5068 let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
5069 let prefilter = Arc::new(NoFilter);
5070 let metrics = Arc::new(NoOpMetricsCollector);
5071
5072 let (row_ids, scores) = index
5073 .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None)
5074 .await
5075 .unwrap();
5076
5077 assert_eq!(row_ids.len(), 5, "row_ids: {:?}", row_ids);
5080 assert!(!row_ids.is_empty(), "Should find at least some documents");
5081 assert_eq!(row_ids.len(), scores.len());
5082
5083 for &score in &scores {
5085 assert!(score > 0.0, "All scores should be positive");
5086 }
5087
5088 assert!(
5090 row_ids.contains(&100),
5091 "Should contain row_id from partition 0"
5092 );
5093 assert!(
5094 row_ids.iter().any(|&id| id >= 200),
5095 "Should contain row_id from partition 1"
5096 );
5097 }
5098
5099 #[tokio::test]
5100 async fn test_modern_prewarm_shrinks_cached_posting_buffers() {
5101 let tmpdir = TempObjDir::default();
5102 let store = Arc::new(LanceIndexStore::new(
5103 ObjectStore::local().into(),
5104 tmpdir.clone(),
5105 Arc::new(LanceCache::no_cache()),
5106 ));
5107
5108 let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
5109 builder.tokens.add("alpha".to_owned());
5110 builder.tokens.add("beta".to_owned());
5111 builder.posting_lists.push(PostingListBuilder::new(false));
5112 builder.posting_lists.push(PostingListBuilder::new(false));
5113 builder.posting_lists[0].add(0, PositionRecorder::Count(1));
5114 builder.posting_lists[0].add(1, PositionRecorder::Count(2));
5115 builder.posting_lists[1].add(2, PositionRecorder::Count(3));
5116 builder.posting_lists[1].add(3, PositionRecorder::Count(4));
5117 builder.docs.append(100, 1);
5118 builder.docs.append(101, 2);
5119 builder.docs.append(102, 3);
5120 builder.docs.append(103, 4);
5121 builder.write(store.as_ref()).await.unwrap();
5122
5123 let metadata = std::collections::HashMap::from_iter(vec![
5124 (
5125 "partitions".to_owned(),
5126 serde_json::to_string(&vec![0u64]).unwrap(),
5127 ),
5128 (
5129 "params".to_owned(),
5130 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5131 ),
5132 (
5133 TOKEN_SET_FORMAT_KEY.to_owned(),
5134 TokenSetFormat::default().to_string(),
5135 ),
5136 ]);
5137 let mut writer = store
5138 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5139 .await
5140 .unwrap();
5141 writer.finish_with_metadata(metadata).await.unwrap();
5142
5143 let cache = Arc::new(LanceCache::with_capacity(4096));
5144 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5145 .await
5146 .unwrap();
5147 let inverted_list = &index.partitions[0].inverted_list;
5148 assert!(
5149 inverted_list.offsets.is_none(),
5150 "test should use modern posting layout"
5151 );
5152
5153 inverted_list.prewarm_posting_lists(false).await.unwrap();
5154
5155 let alpha = inverted_list
5156 .index_cache
5157 .get_with_key(&PostingListKey { token_id: 0 })
5158 .await
5159 .unwrap();
5160 let beta = inverted_list
5161 .index_cache
5162 .get_with_key(&PostingListKey { token_id: 1 })
5163 .await
5164 .unwrap();
5165
5166 let PostingList::Compressed(alpha) = alpha.as_ref() else {
5167 panic!("expected compressed posting list for token 0");
5168 };
5169 let PostingList::Compressed(beta) = beta.as_ref() else {
5170 panic!("expected compressed posting list for token 1");
5171 };
5172
5173 assert_ne!(
5174 alpha.blocks.values().as_ptr(),
5175 beta.blocks.values().as_ptr(),
5176 "prewarm should not leave cached posting lists sharing the same values buffer"
5177 );
5178 }
5179
5180 #[tokio::test]
5181 async fn test_prewarm_with_positions_populates_separate_position_cache() {
5182 let tmpdir = TempObjDir::default();
5183 let store = Arc::new(LanceIndexStore::new(
5184 ObjectStore::local().into(),
5185 tmpdir.clone(),
5186 Arc::new(LanceCache::no_cache()),
5187 ));
5188
5189 let mut builder = InnerBuilder::new_with_format_version(
5190 0,
5191 true,
5192 TokenSetFormat::default(),
5193 InvertedListFormatVersion::V1,
5194 );
5195 builder.tokens.add("hello".to_owned());
5196 builder.tokens.add("world".to_owned());
5197 builder
5198 .posting_lists
5199 .push(PostingListBuilder::new_with_posting_tail_codec(
5200 true,
5201 PostingTailCodec::Fixed32,
5202 ));
5203 builder
5204 .posting_lists
5205 .push(PostingListBuilder::new_with_posting_tail_codec(
5206 true,
5207 PostingTailCodec::Fixed32,
5208 ));
5209 builder.posting_lists[0].add(0, PositionRecorder::Position(vec![0].into()));
5210 builder.posting_lists[1].add(0, PositionRecorder::Position(vec![1].into()));
5211 builder.posting_lists[0].add(1, PositionRecorder::Position(vec![0].into()));
5212 builder.posting_lists[1].add(1, PositionRecorder::Position(vec![2].into()));
5213 builder.docs.append(100, 2);
5214 builder.docs.append(101, 2);
5215 builder.write(store.as_ref()).await.unwrap();
5216
5217 let metadata = std::collections::HashMap::from_iter(vec![
5218 (
5219 "partitions".to_owned(),
5220 serde_json::to_string(&vec![0_u64]).unwrap(),
5221 ),
5222 (
5223 "params".to_owned(),
5224 serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5225 ),
5226 (
5227 TOKEN_SET_FORMAT_KEY.to_owned(),
5228 TokenSetFormat::default().to_string(),
5229 ),
5230 ]);
5231 let mut writer = store
5232 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5233 .await
5234 .unwrap();
5235 writer.finish_with_metadata(metadata).await.unwrap();
5236
5237 let cache = Arc::new(LanceCache::with_capacity(4096));
5238 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5239 .await
5240 .unwrap();
5241
5242 index
5243 .prewarm_with_options(&FtsPrewarmOptions::new().with_position(true))
5244 .await
5245 .unwrap();
5246
5247 let inverted_list = &index.partitions[0].inverted_list;
5248 let posting = inverted_list
5249 .index_cache
5250 .get_with_key(&PostingListKey { token_id: 0 })
5251 .await
5252 .unwrap();
5253 assert!(
5254 !posting.has_position(),
5255 "posting cache should remain positions-free after prewarm"
5256 );
5257
5258 let positions = inverted_list
5259 .index_cache
5260 .get_with_key(&PositionKey { token_id: 0 })
5261 .await
5262 .unwrap();
5263 assert!(
5264 matches!(
5265 positions.as_ref().0,
5266 CompressedPositionStorage::LegacyPerDoc(_)
5267 ),
5268 "positions should be stored in the dedicated position cache"
5269 );
5270 }
5271
5272 #[tokio::test]
5273 async fn test_prewarm_with_v2_positions_preserves_shared_stream_codec() {
5274 let tmpdir = TempObjDir::default();
5275 let store = Arc::new(LanceIndexStore::new(
5276 ObjectStore::local().into(),
5277 tmpdir.clone(),
5278 Arc::new(LanceCache::no_cache()),
5279 ));
5280
5281 let format_version = InvertedListFormatVersion::V2;
5282 let posting_tail_codec = format_version.posting_tail_codec();
5283 let mut builder = InnerBuilder::new_with_format_version(
5284 0,
5285 true,
5286 TokenSetFormat::default(),
5287 format_version,
5288 );
5289 builder.tokens.add("body".to_owned());
5290
5291 let mut posting_list =
5292 PostingListBuilder::new_with_posting_tail_codec(true, posting_tail_codec);
5293 let expected = (0..(BLOCK_SIZE + 5) as u32)
5294 .map(|doc_id| {
5295 let positions = vec![doc_id % 3, doc_id % 3 + 2, doc_id % 3 + 5];
5296 posting_list.add(doc_id, PositionRecorder::Position(positions.clone().into()));
5297 builder.docs.append(30_000 + doc_id as u64, 20 + doc_id % 7);
5298 (doc_id, positions.len() as u32, positions)
5299 })
5300 .collect::<Vec<_>>();
5301 builder.posting_lists.push(posting_list);
5302 builder.write(store.as_ref()).await.unwrap();
5303
5304 let metadata = HashMap::from([
5305 (
5306 "partitions".to_owned(),
5307 serde_json::to_string(&vec![0_u64]).unwrap(),
5308 ),
5309 (
5310 "params".to_owned(),
5311 serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5312 ),
5313 (
5314 TOKEN_SET_FORMAT_KEY.to_owned(),
5315 TokenSetFormat::default().to_string(),
5316 ),
5317 (
5318 POSTING_TAIL_CODEC_KEY.to_owned(),
5319 posting_tail_codec.as_str().to_owned(),
5320 ),
5321 (
5322 POSITIONS_LAYOUT_KEY.to_owned(),
5323 POSITIONS_LAYOUT_SHARED_STREAM_V2.to_owned(),
5324 ),
5325 (
5326 POSITIONS_CODEC_KEY.to_owned(),
5327 PositionStreamCodec::PackedDelta.as_str().to_owned(),
5328 ),
5329 ]);
5330 let mut writer = store
5331 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5332 .await
5333 .unwrap();
5334 writer.finish_with_metadata(metadata).await.unwrap();
5335
5336 let cache = Arc::new(LanceCache::with_capacity(4096));
5337 let index = InvertedIndex::load(store, None, cache.as_ref())
5338 .await
5339 .unwrap();
5340 index
5341 .prewarm_with_options(&FtsPrewarmOptions::new().with_position(true))
5342 .await
5343 .unwrap();
5344
5345 let actual = index.partitions[0]
5346 .inverted_list
5347 .posting_list(0, true, &NoOpMetricsCollector)
5348 .await
5349 .unwrap()
5350 .iter()
5351 .map(|(doc_id, freq, positions)| {
5352 (doc_id as u32, freq, positions.unwrap().collect::<Vec<_>>())
5353 })
5354 .collect::<Vec<_>>();
5355
5356 assert_eq!(actual, expected);
5357 }
5358
/// Checks that `calculate_block_max_scores` returns exactly one entry per
/// block and that the returned vector is not over-allocated.
///
/// The doc set holds three full blocks plus seven extra documents, so the
/// expected block count is `len.div_ceil(BLOCK_SIZE)`.
#[test]
fn test_block_max_scores_capacity_matches_block_count() {
    let total_docs = BLOCK_SIZE * 3 + 7;
    let doc_ids: Vec<u32> = (0..total_docs as u32).collect();

    // Register every document with a length of one token.
    let mut docs = DocSet::default();
    for &doc_id in &doc_ids {
        docs.append(u64::from(doc_id), 1);
    }

    let freqs = vec![1_u32; doc_ids.len()];
    let expected_blocks = doc_ids.len().div_ceil(BLOCK_SIZE);
    let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());

    // Length and capacity must both match the number of blocks.
    assert_eq!(block_max_scores.len(), expected_blocks);
    assert_eq!(block_max_scores.capacity(), expected_blocks);
}
5375
5376 #[tokio::test]
5377 async fn test_bm25_search_uses_global_idf() {
5378 let tmpdir = TempObjDir::default();
5379 let store = Arc::new(LanceIndexStore::new(
5380 ObjectStore::local().into(),
5381 tmpdir.clone(),
5382 Arc::new(LanceCache::no_cache()),
5383 ));
5384
5385 let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default());
5387 builder0.tokens.add("alpha".to_owned());
5388 builder0.tokens.add("beta".to_owned());
5389 builder0.posting_lists.push(PostingListBuilder::new(false));
5390 builder0.posting_lists.push(PostingListBuilder::new(false));
5391 builder0.posting_lists[0].add(0, PositionRecorder::Count(1));
5392 builder0.posting_lists[1].add(1, PositionRecorder::Count(1));
5393 builder0.posting_lists[1].add(2, PositionRecorder::Count(1));
5394 builder0.docs.append(100, 1);
5395 builder0.docs.append(101, 1);
5396 builder0.docs.append(102, 1);
5397 builder0.write(store.as_ref()).await.unwrap();
5398
5399 let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default());
5401 builder1.tokens.add("alpha".to_owned());
5402 builder1.posting_lists.push(PostingListBuilder::new(false));
5403 builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
5404 builder1.docs.append(200, 1);
5405 builder1.write(store.as_ref()).await.unwrap();
5406
5407 let metadata = std::collections::HashMap::from_iter(vec![
5408 (
5409 "partitions".to_owned(),
5410 serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
5411 ),
5412 (
5413 "params".to_owned(),
5414 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5415 ),
5416 (
5417 TOKEN_SET_FORMAT_KEY.to_owned(),
5418 TokenSetFormat::default().to_string(),
5419 ),
5420 ]);
5421 let mut writer = store
5422 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5423 .await
5424 .unwrap();
5425 writer.finish_with_metadata(metadata).await.unwrap();
5426
5427 let cache = Arc::new(LanceCache::with_capacity(4096));
5428 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5429 .await
5430 .unwrap();
5431
5432 let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text));
5433 let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
5434 let prefilter = Arc::new(NoFilter);
5435 let metrics = Arc::new(NoOpMetricsCollector);
5436
5437 let (row_ids, scores) = index
5438 .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None)
5439 .await
5440 .unwrap();
5441
5442 assert_eq!(row_ids.len(), 2);
5443 assert!(row_ids.contains(&100));
5444 assert!(row_ids.contains(&200));
5445 assert_eq!(row_ids.len(), scores.len());
5446
5447 let expected_idf = idf(2, 4);
5448 for score in scores {
5449 assert!(
5450 (score - expected_idf).abs() < 1e-6,
5451 "score: {}, expected: {}",
5452 score,
5453 expected_idf
5454 );
5455 }
5456 }
5457
5458 #[tokio::test]
5459 async fn test_phrase_query_reads_legacy_per_doc_positions() {
5460 let tmpdir = TempObjDir::default();
5461 let store = Arc::new(LanceIndexStore::new(
5462 ObjectStore::local().into(),
5463 tmpdir.clone(),
5464 Arc::new(LanceCache::no_cache()),
5465 ));
5466
5467 let mut builder = InnerBuilder::new_with_format_version(
5468 0,
5469 true,
5470 TokenSetFormat::default(),
5471 InvertedListFormatVersion::V1,
5472 );
5473 builder.tokens.add("hello".to_owned());
5474 builder.tokens.add("world".to_owned());
5475 builder
5476 .posting_lists
5477 .push(PostingListBuilder::new_with_posting_tail_codec(
5478 true,
5479 PostingTailCodec::Fixed32,
5480 ));
5481 builder
5482 .posting_lists
5483 .push(PostingListBuilder::new_with_posting_tail_codec(
5484 true,
5485 PostingTailCodec::Fixed32,
5486 ));
5487 builder.posting_lists[0].add(0, PositionRecorder::Position(vec![0].into()));
5488 builder.posting_lists[1].add(0, PositionRecorder::Position(vec![1].into()));
5489 builder.posting_lists[0].add(1, PositionRecorder::Position(vec![0].into()));
5490 builder.posting_lists[1].add(1, PositionRecorder::Position(vec![2].into()));
5491 builder.docs.append(100, 2);
5492 builder.docs.append(101, 2);
5493 builder.write(store.as_ref()).await.unwrap();
5494
5495 let metadata = std::collections::HashMap::from_iter(vec![
5496 (
5497 "partitions".to_owned(),
5498 serde_json::to_string(&vec![0_u64]).unwrap(),
5499 ),
5500 (
5501 "params".to_owned(),
5502 serde_json::to_string(&InvertedIndexParams::default().with_position(true)).unwrap(),
5503 ),
5504 (
5505 TOKEN_SET_FORMAT_KEY.to_owned(),
5506 TokenSetFormat::default().to_string(),
5507 ),
5508 ]);
5509 let mut writer = store
5510 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5511 .await
5512 .unwrap();
5513 writer.finish_with_metadata(metadata).await.unwrap();
5514
5515 let cache = Arc::new(LanceCache::with_capacity(4096));
5516 let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
5517 .await
5518 .unwrap();
5519
5520 let tokens = Arc::new(Tokens::new(
5521 vec!["hello".to_owned(), "world".to_owned()],
5522 DocType::Text,
5523 ));
5524 let params = Arc::new(
5525 FtsSearchParams::new()
5526 .with_limit(Some(10))
5527 .with_phrase_slop(Some(0)),
5528 );
5529 let prefilter = Arc::new(NoFilter);
5530 let metrics = Arc::new(NoOpMetricsCollector);
5531
5532 let (row_ids, _scores) = index
5533 .bm25_search(tokens, params, Operator::And, prefilter, metrics, None)
5534 .await
5535 .unwrap();
5536
5537 assert_eq!(row_ids, vec![100]);
5538 }
5539
5540 #[tokio::test]
5541 async fn test_update_preserves_loaded_v2_format_version() -> Result<()> {
5542 let src_dir = TempObjDir::default();
5543 let dest_dir = TempObjDir::default();
5544 let src_store = Arc::new(LanceIndexStore::new(
5545 ObjectStore::local().into(),
5546 src_dir.clone(),
5547 Arc::new(LanceCache::no_cache()),
5548 ));
5549 let dest_store = Arc::new(LanceIndexStore::new(
5550 ObjectStore::local().into(),
5551 dest_dir.clone(),
5552 Arc::new(LanceCache::no_cache()),
5553 ));
5554
5555 let format_version = InvertedListFormatVersion::V2;
5556 let posting_tail_codec = format_version.posting_tail_codec();
5557 let mut partition = InnerBuilder::new_with_format_version(
5558 0,
5559 false,
5560 TokenSetFormat::default(),
5561 format_version,
5562 );
5563 partition.tokens.add("hello".to_owned());
5564 let mut posting_list =
5565 PostingListBuilder::new_with_posting_tail_codec(false, posting_tail_codec);
5566 posting_list.add(0, PositionRecorder::Count(1));
5567 partition.posting_lists.push(posting_list);
5568 partition.docs.append(100, 1);
5569 partition.write(src_store.as_ref()).await?;
5570
5571 let metadata = HashMap::from([
5572 (
5573 "partitions".to_owned(),
5574 serde_json::to_string(&vec![0_u64]).unwrap(),
5575 ),
5576 (
5577 "params".to_owned(),
5578 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5579 ),
5580 (
5581 TOKEN_SET_FORMAT_KEY.to_owned(),
5582 TokenSetFormat::default().to_string(),
5583 ),
5584 (
5585 POSTING_TAIL_CODEC_KEY.to_owned(),
5586 posting_tail_codec.as_str().to_owned(),
5587 ),
5588 ]);
5589 let mut writer = src_store
5590 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5591 .await
5592 .unwrap();
5593 writer.finish_with_metadata(metadata).await.unwrap();
5594
5595 let index = InvertedIndex::load(src_store, None, &LanceCache::no_cache()).await?;
5596 assert_eq!(index.index_version(), format_version.index_version());
5597
5598 let schema = Arc::new(Schema::new(vec![
5599 Field::new("doc", DataType::Utf8, true),
5600 Field::new(ROW_ID, DataType::UInt64, false),
5601 ]));
5602 let docs = Arc::new(StringArray::from(vec![Some("hello again")]));
5603 let row_ids = Arc::new(UInt64Array::from(vec![101u64]));
5604 let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?;
5605 let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)]));
5606 let created = index
5607 .update(Box::pin(stream), dest_store.as_ref(), None)
5608 .await?;
5609
5610 assert_eq!(created.index_version, format_version.index_version());
5611
5612 let updated = InvertedIndex::load(dest_store, None, &LanceCache::no_cache()).await?;
5613 assert_eq!(updated.index_version(), format_version.index_version());
5614 assert_eq!(updated.partitions.len(), 2);
5615 for partition in &updated.partitions {
5616 assert_eq!(
5617 partition.inverted_list.posting_tail_codec(),
5618 posting_tail_codec
5619 );
5620 }
5621
5622 Ok(())
5623 }
5624
5625 #[tokio::test]
5626 async fn test_merge_segments_preserves_arrow_token_set_format() -> Result<()> {
5627 let src_dir = TempObjDir::default();
5628 let dest_dir = TempObjDir::default();
5629 let src_store = Arc::new(LanceIndexStore::new(
5630 ObjectStore::local().into(),
5631 src_dir.clone(),
5632 Arc::new(LanceCache::no_cache()),
5633 ));
5634 let dest_store = Arc::new(LanceIndexStore::new(
5635 ObjectStore::local().into(),
5636 dest_dir.clone(),
5637 Arc::new(LanceCache::no_cache()),
5638 ));
5639
5640 let index = write_single_partition_index(
5641 src_store,
5642 InvertedIndexParams::default(),
5643 TokenSetFormat::Arrow,
5644 "hello",
5645 100,
5646 )
5647 .await?;
5648 let created = InvertedIndex::merge_segments(
5649 &[index],
5650 empty_doc_stream(),
5651 dest_store.as_ref(),
5652 None,
5653 crate::progress::noop_progress(),
5654 )
5655 .await?;
5656
5657 assert_eq!(created.index_version, 0);
5658 let merged = InvertedIndex::load(dest_store, None, &LanceCache::no_cache()).await?;
5659 assert_eq!(merged.token_set_format, TokenSetFormat::Arrow);
5660
5661 let tokens = Arc::new(Tokens::new(vec!["hello".to_string()], DocType::Text));
5662 let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
5663 let prefilter = Arc::new(NoFilter);
5664 let metrics = Arc::new(NoOpMetricsCollector);
5665 let (row_ids, _) = merged
5666 .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None)
5667 .await?;
5668 assert_eq!(row_ids, vec![100]);
5669
5670 Ok(())
5671 }
5672
5673 #[tokio::test]
5674 async fn test_merge_segments_uses_memory_limit_for_old_partitions() -> Result<()> {
5675 let src_dir_1 = TempObjDir::default();
5676 let src_dir_2 = TempObjDir::default();
5677 let dest_dir = TempObjDir::default();
5678 let src_store_1 = Arc::new(LanceIndexStore::new(
5679 ObjectStore::local().into(),
5680 src_dir_1.clone(),
5681 Arc::new(LanceCache::no_cache()),
5682 ));
5683 let src_store_2 = Arc::new(LanceIndexStore::new(
5684 ObjectStore::local().into(),
5685 src_dir_2.clone(),
5686 Arc::new(LanceCache::no_cache()),
5687 ));
5688 let dest_store = Arc::new(LanceIndexStore::new(
5689 ObjectStore::local().into(),
5690 dest_dir.clone(),
5691 Arc::new(LanceCache::no_cache()),
5692 ));
5693
5694 let params = InvertedIndexParams::default().memory_limit_mb(0);
5695 let first = write_single_partition_index(
5696 src_store_1,
5697 params.clone(),
5698 TokenSetFormat::default(),
5699 "alpha",
5700 100,
5701 )
5702 .await?;
5703 let second = write_single_partition_index(
5704 src_store_2,
5705 params,
5706 TokenSetFormat::default(),
5707 "beta",
5708 200,
5709 )
5710 .await?;
5711
5712 let mut builder =
5713 InvertedIndexBuilder::new(InvertedIndexParams::default().memory_limit_mb(0))
5714 .with_token_set_format(TokenSetFormat::default());
5715 builder
5716 .update_from_segments(
5717 empty_doc_stream(),
5718 dest_store.as_ref(),
5719 &[first, second],
5720 None,
5721 )
5722 .await?;
5723
5724 let merged = InvertedIndex::load(dest_store, None, &LanceCache::no_cache()).await?;
5725 assert_eq!(merged.partitions.len(), 2);
5726 let mut partition_ids = merged
5727 .partitions
5728 .iter()
5729 .map(|partition| partition.id())
5730 .collect::<Vec<_>>();
5731 partition_ids.sort_unstable();
5732 assert_eq!(partition_ids, vec![0, 1]);
5733
5734 Ok(())
5735 }
5736
5737 #[tokio::test]
5738 async fn test_modern_index_without_deleted_col_has_empty_bitmap() {
5739 let tmpdir = TempObjDir::default();
5743 let store = Arc::new(LanceIndexStore::new(
5744 ObjectStore::local().into(),
5745 tmpdir.clone(),
5746 Arc::new(LanceCache::no_cache()),
5747 ));
5748
5749 let mut builder = InnerBuilder::new(0, false, TokenSetFormat::default());
5750 builder.tokens.add("test".to_owned());
5751 builder.posting_lists.push(PostingListBuilder::new(false));
5752 builder.posting_lists[0].add(0, PositionRecorder::Count(1));
5753 builder.docs.append(100, 1);
5754 builder.write(store.as_ref()).await.unwrap();
5755
5756 let metadata = std::collections::HashMap::from_iter(vec![
5759 (
5760 "partitions".to_owned(),
5761 serde_json::to_string(&vec![0u64]).unwrap(),
5762 ),
5763 (
5764 "params".to_owned(),
5765 serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
5766 ),
5767 (
5768 TOKEN_SET_FORMAT_KEY.to_owned(),
5769 TokenSetFormat::default().to_string(),
5770 ),
5771 ]);
5772 let mut writer = store
5773 .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
5774 .await
5775 .unwrap();
5776 writer.finish_with_metadata(metadata).await.unwrap();
5777
5778 let index = InvertedIndex::load(store, None, &LanceCache::no_cache())
5779 .await
5780 .unwrap();
5781 assert!(
5782 index.deleted_fragments().is_empty(),
5783 "index without deleted_fragments column should have empty bitmap"
5784 );
5785 }
5786
5787 #[tokio::test]
5788 async fn flat_bm25_search_stream_with_metrics_records_elapsed_compute() {
5789 use crate::scalar::inverted::tokenizer::document_tokenizer::TextTokenizer;
5790 use arrow_array::{StringArray, UInt64Array};
5791 use lance_tokenizer::{SimpleTokenizer, TextAnalyzer};
5792
5793 let schema = Arc::new(Schema::new(vec![
5795 ROW_ID_FIELD.clone(),
5796 Field::new("text", DataType::Utf8, false),
5797 ]));
5798 let batch = RecordBatch::try_new(
5799 schema.clone(),
5800 vec![
5801 Arc::new(UInt64Array::from(vec![0u64, 1, 2, 3])),
5802 Arc::new(StringArray::from(vec![
5803 "the quick brown fox",
5804 "lazy dog sleeps",
5805 "the brown fox jumps over",
5806 "completely unrelated text",
5807 ])),
5808 ],
5809 )
5810 .unwrap();
5811
5812 let input: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
5813 schema.clone(),
5814 stream::iter(vec![Ok(batch)]),
5815 ));
5816
5817 let tokenizer: Box<dyn LanceTokenizer> = Box::new(TextTokenizer::new(
5818 TextAnalyzer::builder(SimpleTokenizer::default()).build(),
5819 ));
5820
5821 let elapsed_compute = Time::default();
5822 let result_stream = flat_bm25_search_stream_with_metrics(
5823 input,
5824 "text".to_string(),
5825 "fox".to_string(),
5826 tokenizer,
5827 None,
5828 100,
5829 Some(elapsed_compute.clone()),
5830 )
5831 .await
5832 .unwrap();
5833
5834 let batches: Vec<_> = result_stream.try_collect().await.unwrap();
5835 assert!(!batches.is_empty(), "expected at least one scored batch");
5836
5837 assert!(
5841 elapsed_compute.value() > 0,
5842 "elapsed_compute should have been populated; got 0"
5843 );
5844 }
5845}