1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Top-k retrieval query over a sparse vector field.
///
/// Holds the raw (dimension, weight) query vector plus pruning knobs; a
/// pruned copy of the vector is precomputed and cached in `pruned`
/// whenever a knob changes via the `with_*` builders.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Target sparse vector field.
    pub field: Field,
    /// Raw (dimension id, weight) pairs as supplied by the caller.
    pub vector: Vec<(u32, f32)>,
    /// How scores from multiple values per document are combined.
    pub combiner: MultiValueCombiner,
    /// Approximation factor in [0, 1]; values below 1.0 enable
    /// approximate top-k (see the BMP execution path).
    pub heap_factor: f32,
    /// Dimensions with |weight| below this are dropped (0.0 keeps all).
    pub weight_threshold: f32,
    /// Hard cap on the number of query dimensions after pruning.
    pub max_query_dims: Option<usize>,
    /// Keep only this fraction of dimensions (by |weight|), if set.
    pub pruning: Option<f32>,
    /// Threshold/fraction pruning is skipped while the vector has at
    /// most this many dimensions.
    pub min_query_dims: usize,
    /// Candidate over-fetch multiplier (>= 1.0) used during execution.
    pub over_fetch_factor: f32,
    // Cached result of `compute_pruned_vector`; refreshed by every
    // builder that affects pruning.
    pruned: Option<Vec<(u32, f32)>>,
}
42
43impl std::fmt::Display for SparseVectorQuery {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 let dims = self.pruned_dims();
46 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
47 if self.heap_factor < 1.0 {
48 write!(f, ", heap={}", self.heap_factor)?;
49 }
50 if self.vector.len() != dims.len() {
51 write!(f, ", orig={}", self.vector.len())?;
52 }
53 write!(f, ")")
54 }
55}
56
impl SparseVectorQuery {
    /// Creates a query over `field` with default settings and eagerly
    /// populates the pruned-vector cache.
    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
        let mut q = Self {
            field,
            vector,
            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: Some(crate::query::MAX_QUERY_TERMS),
            pruning: None,
            min_query_dims: 4,
            over_fetch_factor: 2.0,
            pruned: None,
        };
        // Cache the pruned view once; the with_* builders refresh it.
        q.pruned = Some(q.compute_pruned_vector());
        q
    }

    /// The dimensions actually executed: the cached pruned vector, or
    /// the raw vector if the cache was never populated.
    pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
        self.pruned.as_deref().unwrap_or(&self.vector)
    }

    /// Sets how multiple matching values per document are combined.
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }

    /// Sets the candidate over-fetch multiplier (floored at 1.0).
    pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
        self.over_fetch_factor = factor.max(1.0);
        self
    }

    /// Sets the approximation factor, clamped to [0.0, 1.0]; values
    /// below 1.0 trade recall for speed.
    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
        self.heap_factor = heap_factor.clamp(0.0, 1.0);
        self
    }

    /// Drops dimensions whose |weight| is below `threshold` and
    /// refreshes the pruned cache.
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Caps the number of query dimensions and refreshes the cache.
    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
        self.max_query_dims = Some(max_dims);
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Keeps only `fraction` (clamped to [0, 1]) of the dimensions,
    /// highest |weight| first, and refreshes the cache.
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Sets the size below which pruning is skipped and refreshes the
    /// cache.
    pub fn with_min_query_dims(mut self, min_dims: usize) -> Self {
        self.min_query_dims = min_dims;
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Applies, in order: weight-threshold filtering, fractional
    /// pruning, then the max-dims cap. Each stage runs only when its
    /// knob is set and the vector is large enough.
    fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
        let original_len = self.vector.len();

        // Stage 1: drop dimensions below the weight threshold (only
        // when the vector exceeds min_query_dims).
        let mut v: Vec<(u32, f32)> =
            if self.weight_threshold > 0.0 && self.vector.len() > self.min_query_dims {
                self.vector
                    .iter()
                    .copied()
                    .filter(|(_, w)| w.abs() >= self.weight_threshold)
                    .collect()
            } else {
                self.vector.clone()
            };
        let after_threshold = v.len();

        // Stage 2: keep the top `fraction` of dimensions by |weight|.
        let mut sorted_by_weight = false;
        if let Some(fraction) = self.pruning
            && fraction < 1.0
            && v.len() > self.min_query_dims
        {
            // Descending by |weight|; NaN weights compare Equal.
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            sorted_by_weight = true;
            // ceil() so a tiny fraction still keeps at least one dim.
            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
            v.truncate(keep);
        }
        let after_pruning = v.len();

        // Stage 3: hard cap, reusing stage 2's sort when available.
        if let Some(max_dims) = self.max_query_dims
            && v.len() > max_dims
        {
            if !sorted_by_weight {
                v.sort_unstable_by(|a, b| {
                    b.1.abs()
                        .partial_cmp(&a.1.abs())
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
            v.truncate(max_dims);
        }

        // Trace exactly what each stage removed (guarded so the vector
        // formatting only happens when debug logging is enabled).
        if v.len() < original_len && log::log_enabled!(log::Level::Debug) {
            let src: Vec<_> = self
                .vector
                .iter()
                .map(|(d, w)| format!("({},{:.4})", d, w))
                .collect();
            let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
            log::debug!(
                "[sparse query] field={}: pruned {}->{} dims \
                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
                 source=[{}], pruned=[{}]",
                self.field.0,
                original_len,
                v.len(),
                original_len,
                after_threshold,
                after_threshold,
                after_pruning,
                after_pruning,
                v.len(),
                src.join(", "),
                pruned_fmt.join(", "),
            );
        }

        v
    }

    /// Builds a query by zipping parallel `indices`/`weights` arrays;
    /// note extra elements in the longer array are silently dropped.
    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
        Self::new(field, vector)
    }

    /// Tokenizes `text` with a named, cached tokenizer and weights each
    /// unique token according to `weighting`; every weighting mode
    /// falls back to uniform 1.0 weights when its statistics source is
    /// unavailable.
    #[cfg(feature = "native")]
    pub fn from_text(
        field: Field,
        text: &str,
        tokenizer_name: &str,
        weighting: crate::structures::QueryWeighting,
        sparse_index: Option<&crate::segment::SparseIndex>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::tokenizer_cache;

        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(index) = sparse_index {
                    index.idf_weights(&token_ids)
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                use crate::tokenizer::idf_weights_cache;
                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name, None) {
                    token_ids.iter().map(|&id| idf.get(id)).collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`Self::from_text`] but uses a caller-supplied tokenizer
    /// and optional cross-segment statistics; IDF weights are clamped
    /// to be non-negative.
    #[cfg(feature = "native")]
    pub fn from_text_with_stats(
        field: Field,
        text: &str,
        tokenizer: &crate::tokenizer::HfTokenizer,
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&crate::query::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;

        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No weights-file source in this path; use uniform weights.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`Self::from_text_with_stats`] but deserializes the
    /// tokenizer from raw bytes first.
    #[cfg(feature = "native")]
    pub fn from_text_with_tokenizer_bytes(
        field: Field,
        text: &str,
        tokenizer_bytes: &[u8],
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&crate::query::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::HfTokenizer;

        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No weights-file source in this path; use uniform weights.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }
}
373
374impl SparseVectorQuery {
375 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
377 self.pruned_dims()
378 .iter()
379 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
380 field: self.field,
381 dim_id,
382 weight,
383 heap_factor: self.heap_factor,
384 combiner: self.combiner,
385 over_fetch_factor: self.over_fetch_factor,
386 })
387 .collect()
388 }
389}
390
impl Query for SparseVectorQuery {
    /// Async scorer: tries the BMP path first, then a MaxScore
    /// executor; yields an empty scorer when the query has no
    /// dimensions or neither execution path is available.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        // Snapshot per-term infos now so the future doesn't borrow self.
        let infos = self.sparse_infos();

        Box::pin(async move {
            if infos.is_empty() {
                return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
            }

            // Fast path: block-max pruning (BMP) results from the planner.
            if let Some((raw, info)) =
                crate::query::planner::build_sparse_bmp_results(&infos, reader, limit)
            {
                return Ok(crate::query::planner::combine_sparse_results(
                    raw,
                    info.combiner,
                    info.field,
                    limit,
                ));
            }

            // Fallback: MaxScore executor over the posting lists.
            if let Some((executor, info)) =
                crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
            {
                let raw = executor.execute().await?;
                return Ok(crate::query::planner::combine_sparse_results(
                    raw,
                    info.combiner,
                    info.field,
                    limit,
                ));
            }

            Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
        })
    }

    /// Synchronous twin of [`Self::scorer`]; same plan selection order.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let infos = self.sparse_infos();
        if infos.is_empty() {
            return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
        }

        // Fast path: BMP results from the planner.
        if let Some((raw, info)) =
            crate::query::planner::build_sparse_bmp_results(&infos, reader, limit)
        {
            return Ok(crate::query::planner::combine_sparse_results(
                raw,
                info.combiner,
                info.field,
                limit,
            ));
        }

        // Fallback: MaxScore executor, executed synchronously.
        if let Some((executor, info)) =
            crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
        {
            let raw = executor.execute_sync()?;
            return Ok(crate::query::planner::combine_sparse_results(
                raw,
                info.combiner,
                info.field,
                limit,
            ));
        }

        Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>)
    }

    /// No cheap estimate exists for multi-dimension sparse queries;
    /// report the maximum so callers never under-allocate.
    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async move { Ok(u32::MAX) })
    }

    /// Exposes the query as per-dimension terms so enclosing queries
    /// can plan with them; opaque when there are no dimensions.
    fn decompose(&self) -> crate::query::QueryDecomposition {
        let infos = self.sparse_infos();
        if infos.is_empty() {
            crate::query::QueryDecomposition::Opaque
        } else {
            crate::query::QueryDecomposition::SparseTerms(infos)
        }
    }
}
481
/// Query matching a single dimension of a sparse vector field.
///
/// Equivalent to a one-entry [`SparseVectorQuery`]; also the unit that
/// sparse queries decompose into for planning.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Target sparse vector field.
    pub field: Field,
    /// Dimension id to match.
    pub dim_id: u32,
    /// Query-side weight applied to matching postings.
    pub weight: f32,
    /// Approximation factor used by the BMP fallback path.
    pub heap_factor: f32,
    /// How scores from multiple values per document are combined.
    pub combiner: MultiValueCombiner,
    /// Candidate over-fetch multiplier carried into planning.
    pub over_fetch_factor: f32,
}
501
502impl std::fmt::Display for SparseTermQuery {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 write!(
505 f,
506 "SparseTerm({}, dim={}, w={:.3})",
507 self.field.0, self.dim_id, self.weight
508 )
509 }
510}
511
512impl SparseTermQuery {
513 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
514 Self {
515 field,
516 dim_id,
517 weight,
518 heap_factor: 1.0,
519 combiner: MultiValueCombiner::default(),
520 over_fetch_factor: 2.0,
521 }
522 }
523
524 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
525 self.heap_factor = heap_factor;
526 self
527 }
528
529 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
530 self.combiner = combiner;
531 self
532 }
533
534 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
535 self.over_fetch_factor = factor.max(1.0);
536 self
537 }
538
539 fn bmp_fallback_scorer<'a>(
541 &self,
542 reader: &'a SegmentReader,
543 limit: usize,
544 ) -> crate::Result<Box<dyn Scorer + 'a>> {
545 if let Some(bmp) = reader.bmp_index(self.field) {
546 let results = crate::query::bmp::execute_bmp(
547 bmp,
548 &[(self.dim_id, self.weight)],
549 limit,
550 self.heap_factor,
551 0,
552 )?;
553 let combined = crate::segment::combine_ordinal_results(
554 results.into_iter().map(|r| (r.doc_id, r.ordinal, r.score)),
555 self.combiner,
556 limit,
557 );
558 return Ok(Box::new(
559 crate::query::planner::VectorTopKResultScorer::new(combined, self.field.0),
560 ));
561 }
562 Ok(Box::new(crate::query::EmptyScorer))
563 }
564
565 fn make_scorer<'a>(
568 &self,
569 reader: &'a SegmentReader,
570 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
571 let si = match reader.sparse_index(self.field) {
572 Some(si) => si,
573 None => return Ok(None),
574 };
575 let (skip_start, skip_count, global_max, block_data_offset) =
576 match si.get_skip_range_full(self.dim_id) {
577 Some(v) => v,
578 None => return Ok(None),
579 };
580 let cursor = crate::query::TermCursor::sparse(
581 si,
582 self.weight,
583 skip_start,
584 skip_count,
585 global_max,
586 block_data_offset,
587 );
588 Ok(Some(SparseTermScorer {
589 cursor,
590 field_id: self.field.0,
591 }))
592 }
593}
594
impl Query for SparseTermQuery {
    /// Async scorer: streams the dimension's posting list, falling
    /// back to the BMP index when no sparse index serves the field.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        // Clone so the returned future doesn't borrow self.
        let query = self.clone();
        Box::pin(async move {
            let mut scorer = match query.make_scorer(reader)? {
                Some(s) => s,
                None => return query.bmp_fallback_scorer(reader, limit),
            };
            // Best-effort prefetch of the first block; failures surface
            // later during advance/seek, so the error is discarded here.
            scorer.cursor.ensure_block_loaded().await.ok();
            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
        })
    }

    /// Synchronous twin of [`Self::scorer`].
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let mut scorer = match self.make_scorer(reader)? {
            Some(s) => s,
            None => return self.bmp_fallback_scorer(reader, limit),
        };
        // Best-effort prefetch; see the async variant.
        scorer.cursor.ensure_block_loaded_sync().ok();
        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
    }

    /// Upper-bound count from skip metadata: each skip entry covers at
    /// most 256 documents.
    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
        let field = self.field;
        let dim_id = self.dim_id;
        Box::pin(async move {
            let si = match reader.sparse_index(field) {
                Some(si) => si,
                None => return Ok(0),
            };
            match si.get_skip_range_full(dim_id) {
                Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
                None => Ok(0),
            }
        })
    }

    /// A term query decomposes into exactly its own descriptor.
    fn decompose(&self) -> crate::query::QueryDecomposition {
        crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
            field: self.field,
            dim_id: self.dim_id,
            weight: self.weight,
            heap_factor: self.heap_factor,
            combiner: self.combiner,
            over_fetch_factor: self.over_fetch_factor,
        }])
    }
}
648
/// Streaming scorer over a single sparse dimension's posting list.
struct SparseTermScorer<'a> {
    // Posting-list cursor; its doc/advance/seek methods report
    // u32::MAX once the list is exhausted (normalized to TERMINATED
    // by the DocSet impl below).
    cursor: crate::query::TermCursor<'a>,
    // Raw field id, reported alongside matched positions.
    field_id: u32,
}
657
658impl crate::query::docset::DocSet for SparseTermScorer<'_> {
659 fn doc(&self) -> DocId {
660 let d = self.cursor.doc();
661 if d == u32::MAX { TERMINATED } else { d }
662 }
663
664 fn advance(&mut self) -> DocId {
665 match self.cursor.advance_sync() {
666 Ok(d) if d == u32::MAX => TERMINATED,
667 Ok(d) => d,
668 Err(_) => TERMINATED,
669 }
670 }
671
672 fn seek(&mut self, target: DocId) -> DocId {
673 match self.cursor.seek_sync(target) {
674 Ok(d) if d == u32::MAX => TERMINATED,
675 Ok(d) => d,
676 Err(_) => TERMINATED,
677 }
678 }
679
680 fn size_hint(&self) -> u32 {
681 0
682 }
683}
684
685impl Scorer for SparseTermScorer<'_> {
686 fn score(&self) -> Score {
687 self.cursor.score()
688 }
689
690 fn matched_positions(&self) -> Option<MatchedPositions> {
691 let ordinal = self.cursor.ordinal();
692 let score = self.cursor.score();
693 if score == 0.0 {
694 return None;
695 }
696 Some(vec![(
697 self.field_id,
698 vec![ScoredPosition::new(ordinal as u32, score)],
699 )])
700 }
701}
702
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    // Construction preserves the field and the raw vector verbatim.
    #[test]
    fn test_sparse_vector_query_new() {
        let dims = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), dims.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, dims);
    }

    // Parallel index/weight arrays zip into (dim, weight) pairs.
    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query = SparseVectorQuery::from_indices_weights(
            Field(0),
            vec![1, 5, 10],
            vec![0.5, 0.3, 0.2],
        );

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}