1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Top-k query over a sparse vector field, expressed as `(dimension id, weight)`
/// pairs. The query vector is pruned once up front (threshold, fraction, hard
/// cap — see `compute_pruned_vector`) and the pruned form is what executes.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Target sparse-vector field.
    pub field: Field,
    /// Raw query vector as supplied by the caller, before any pruning.
    pub vector: Vec<(u32, f32)>,
    /// How multiple matches within one document are combined
    /// (defaults to `LogSumExp { temperature: 0.7 }` — see `new`).
    pub combiner: MultiValueCombiner,
    /// Approximation factor forwarded to BMP execution; clamped to
    /// `[0.0, 1.0]` by `with_heap_factor`. `Display` treats `1.0` as the
    /// default (presumably exact search — TODO confirm against the executor).
    pub heap_factor: f32,
    /// Minimum absolute weight a query dimension must have to survive
    /// pruning (applied only when the vector has more than
    /// `min_query_dims` entries).
    pub weight_threshold: f32,
    /// Hard cap on the number of query dimensions kept after pruning;
    /// defaults to `crate::query::MAX_QUERY_TERMS`.
    pub max_query_dims: Option<usize>,
    /// Optional fraction (clamped to `[0.0, 1.0]`) of the highest-|weight|
    /// dimensions to keep.
    pub pruning: Option<f32>,
    /// Threshold/fraction pruning is skipped while the vector has at most
    /// this many dimensions (default 4).
    pub min_query_dims: usize,
    /// Over-fetch multiplier for per-term candidate retrieval; clamped to
    /// `>= 1.0` by `with_over_fetch_factor`.
    pub over_fetch_factor: f32,
    /// Superblock budget forwarded to term execution (default 0; presumably
    /// "no limit" — TODO confirm against the BMP executor).
    pub max_superblocks: usize,
    /// Cached result of `compute_pruned_vector`, refreshed by every builder
    /// method that affects pruning.
    pruned: Option<Vec<(u32, f32)>>,
}
44
45impl std::fmt::Display for SparseVectorQuery {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 let dims = self.pruned_dims();
48 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
49 if self.heap_factor < 1.0 {
50 write!(f, ", heap={}", self.heap_factor)?;
51 }
52 if self.vector.len() != dims.len() {
53 write!(f, ", orig={}", self.vector.len())?;
54 }
55 write!(f, ")")
56 }
57}
58
59impl SparseVectorQuery {
60 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
67 let mut q = Self {
68 field,
69 vector,
70 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
71 heap_factor: 1.0,
72 weight_threshold: 0.0,
73 max_query_dims: Some(crate::query::MAX_QUERY_TERMS),
74 pruning: None,
75 min_query_dims: 4,
76 over_fetch_factor: 2.0,
77 max_superblocks: 0,
78 pruned: None,
79 };
80 q.pruned = Some(q.compute_pruned_vector());
81 q
82 }
83
84 pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
86 self.pruned.as_deref().unwrap_or(&self.vector)
87 }
88
89 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
91 self.combiner = combiner;
92 self
93 }
94
95 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
100 self.over_fetch_factor = factor.max(1.0);
101 self
102 }
103
104 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
111 self.heap_factor = heap_factor.clamp(0.0, 1.0);
112 self
113 }
114
115 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
118 self.weight_threshold = threshold;
119 self.pruned = Some(self.compute_pruned_vector());
120 self
121 }
122
123 pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
125 self.max_query_dims = Some(max_dims);
126 self.pruned = Some(self.compute_pruned_vector());
127 self
128 }
129
130 pub fn with_pruning(mut self, fraction: f32) -> Self {
133 self.pruning = Some(fraction.clamp(0.0, 1.0));
134 self.pruned = Some(self.compute_pruned_vector());
135 self
136 }
137
138 pub fn with_min_query_dims(mut self, min_dims: usize) -> Self {
141 self.min_query_dims = min_dims;
142 self.pruned = Some(self.compute_pruned_vector());
143 self
144 }
145
146 fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
148 let original_len = self.vector.len();
149
150 let mut v: Vec<(u32, f32)> =
153 if self.weight_threshold > 0.0 && self.vector.len() > self.min_query_dims {
154 self.vector
155 .iter()
156 .copied()
157 .filter(|(_, w)| w.abs() >= self.weight_threshold)
158 .collect()
159 } else {
160 self.vector.clone()
161 };
162 let after_threshold = v.len();
163
164 let mut sorted_by_weight = false;
167 if let Some(fraction) = self.pruning
168 && fraction < 1.0
169 && v.len() > self.min_query_dims
170 {
171 v.sort_unstable_by(|a, b| {
172 b.1.abs()
173 .partial_cmp(&a.1.abs())
174 .unwrap_or(std::cmp::Ordering::Equal)
175 });
176 sorted_by_weight = true;
177 let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
178 v.truncate(keep);
179 }
180 let after_pruning = v.len();
181
182 if let Some(max_dims) = self.max_query_dims
184 && v.len() > max_dims
185 {
186 if !sorted_by_weight {
187 v.sort_unstable_by(|a, b| {
188 b.1.abs()
189 .partial_cmp(&a.1.abs())
190 .unwrap_or(std::cmp::Ordering::Equal)
191 });
192 }
193 v.truncate(max_dims);
194 }
195
196 if v.len() < original_len && log::log_enabled!(log::Level::Debug) {
197 let src: Vec<_> = self
198 .vector
199 .iter()
200 .map(|(d, w)| format!("({},{:.4})", d, w))
201 .collect();
202 let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
203 log::debug!(
204 "[sparse query] field={}: pruned {}->{} dims \
205 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
206 source=[{}], pruned=[{}]",
207 self.field.0,
208 original_len,
209 v.len(),
210 original_len,
211 after_threshold,
212 after_threshold,
213 after_pruning,
214 after_pruning,
215 v.len(),
216 src.join(", "),
217 pruned_fmt.join(", "),
218 );
219 }
220
221 v
222 }
223
224 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
226 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
227 Self::new(field, vector)
228 }
229
230 #[cfg(feature = "native")]
242 pub fn from_text(
243 field: Field,
244 text: &str,
245 tokenizer_name: &str,
246 weighting: crate::structures::QueryWeighting,
247 sparse_index: Option<&crate::segment::SparseIndex>,
248 ) -> crate::Result<Self> {
249 use crate::structures::QueryWeighting;
250 use crate::tokenizer::tokenizer_cache;
251
252 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
253 let token_ids = tokenizer.tokenize_unique(text)?;
254
255 let weights: Vec<f32> = match weighting {
256 QueryWeighting::One => vec![1.0f32; token_ids.len()],
257 QueryWeighting::Idf => {
258 if let Some(index) = sparse_index {
259 index.idf_weights(&token_ids)
260 } else {
261 vec![1.0f32; token_ids.len()]
262 }
263 }
264 QueryWeighting::IdfFile => {
265 use crate::tokenizer::idf_weights_cache;
266 if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name, None) {
267 token_ids.iter().map(|&id| idf.get(id)).collect()
268 } else {
269 vec![1.0f32; token_ids.len()]
270 }
271 }
272 };
273
274 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
275 Ok(Self::new(field, vector))
276 }
277
278 #[cfg(feature = "native")]
290 pub fn from_text_with_stats(
291 field: Field,
292 text: &str,
293 tokenizer: &crate::tokenizer::HfTokenizer,
294 weighting: crate::structures::QueryWeighting,
295 global_stats: Option<&crate::query::GlobalStats>,
296 ) -> crate::Result<Self> {
297 use crate::structures::QueryWeighting;
298
299 let token_ids = tokenizer.tokenize_unique(text)?;
300
301 let weights: Vec<f32> = match weighting {
302 QueryWeighting::One => vec![1.0f32; token_ids.len()],
303 QueryWeighting::Idf => {
304 if let Some(stats) = global_stats {
305 stats
307 .sparse_idf_weights(field, &token_ids)
308 .into_iter()
309 .map(|w| w.max(0.0))
310 .collect()
311 } else {
312 vec![1.0f32; token_ids.len()]
313 }
314 }
315 QueryWeighting::IdfFile => {
316 vec![1.0f32; token_ids.len()]
319 }
320 };
321
322 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
323 Ok(Self::new(field, vector))
324 }
325
326 #[cfg(feature = "native")]
338 pub fn from_text_with_tokenizer_bytes(
339 field: Field,
340 text: &str,
341 tokenizer_bytes: &[u8],
342 weighting: crate::structures::QueryWeighting,
343 global_stats: Option<&crate::query::GlobalStats>,
344 ) -> crate::Result<Self> {
345 use crate::structures::QueryWeighting;
346 use crate::tokenizer::HfTokenizer;
347
348 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
349 let token_ids = tokenizer.tokenize_unique(text)?;
350
351 let weights: Vec<f32> = match weighting {
352 QueryWeighting::One => vec![1.0f32; token_ids.len()],
353 QueryWeighting::Idf => {
354 if let Some(stats) = global_stats {
355 stats
357 .sparse_idf_weights(field, &token_ids)
358 .into_iter()
359 .map(|w| w.max(0.0))
360 .collect()
361 } else {
362 vec![1.0f32; token_ids.len()]
363 }
364 }
365 QueryWeighting::IdfFile => {
366 vec![1.0f32; token_ids.len()]
369 }
370 };
371
372 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
373 Ok(Self::new(field, vector))
374 }
375}
376
377impl SparseVectorQuery {
378 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
380 self.pruned_dims()
381 .iter()
382 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
383 field: self.field,
384 dim_id,
385 weight,
386 heap_factor: self.heap_factor,
387 combiner: self.combiner,
388 over_fetch_factor: self.over_fetch_factor,
389 max_superblocks: self.max_superblocks,
390 })
391 .collect()
392 }
393}
394
395impl Query for SparseVectorQuery {
396 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
397 let infos = self.sparse_infos();
398
399 Box::pin(async move {
400 if infos.is_empty() {
401 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
402 }
403
404 if let Some((raw, info)) =
406 crate::query::planner::build_sparse_bmp_results(&infos, reader, limit)
407 {
408 return Ok(crate::query::planner::combine_sparse_results(
409 raw,
410 info.combiner,
411 info.field,
412 limit,
413 ));
414 }
415
416 if let Some((executor, info)) =
418 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
419 {
420 let raw = executor.execute().await?;
421 return Ok(crate::query::planner::combine_sparse_results(
422 raw,
423 info.combiner,
424 info.field,
425 limit,
426 ));
427 }
428
429 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
430 })
431 }
432
433 #[cfg(feature = "sync")]
434 fn scorer_sync<'a>(
435 &self,
436 reader: &'a SegmentReader,
437 limit: usize,
438 ) -> crate::Result<Box<dyn Scorer + 'a>> {
439 let infos = self.sparse_infos();
440 if infos.is_empty() {
441 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
442 }
443
444 if let Some((raw, info)) =
446 crate::query::planner::build_sparse_bmp_results(&infos, reader, limit)
447 {
448 return Ok(crate::query::planner::combine_sparse_results(
449 raw,
450 info.combiner,
451 info.field,
452 limit,
453 ));
454 }
455
456 if let Some((executor, info)) =
458 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
459 {
460 let raw = executor.execute_sync()?;
461 return Ok(crate::query::planner::combine_sparse_results(
462 raw,
463 info.combiner,
464 info.field,
465 limit,
466 ));
467 }
468
469 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>)
470 }
471
472 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
473 Box::pin(async move { Ok(u32::MAX) })
474 }
475
476 fn decompose(&self) -> crate::query::QueryDecomposition {
477 let infos = self.sparse_infos();
478 if infos.is_empty() {
479 crate::query::QueryDecomposition::Opaque
480 } else {
481 crate::query::QueryDecomposition::SparseTerms(infos)
482 }
483 }
484}
485
/// Query over a single dimension of a sparse vector field.
///
/// Used standalone or as the unit produced by `SparseVectorQuery::decompose`.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Target sparse-vector field.
    pub field: Field,
    /// The single dimension id to match.
    pub dim_id: u32,
    /// Query-side weight applied to this dimension's stored values.
    pub weight: f32,
    /// Approximation factor forwarded to BMP execution (default 1.0;
    /// NOTE(review): unlike `SparseVectorQuery`, the builder here does not
    /// clamp it — confirm intended range).
    pub heap_factor: f32,
    /// How multiple matches within one document are combined.
    pub combiner: MultiValueCombiner,
    /// Over-fetch multiplier for candidate retrieval; clamped to `>= 1.0`
    /// by `with_over_fetch_factor`.
    pub over_fetch_factor: f32,
}
505
506impl std::fmt::Display for SparseTermQuery {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 write!(
509 f,
510 "SparseTerm({}, dim={}, w={:.3})",
511 self.field.0, self.dim_id, self.weight
512 )
513 }
514}
515
516impl SparseTermQuery {
517 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
518 Self {
519 field,
520 dim_id,
521 weight,
522 heap_factor: 1.0,
523 combiner: MultiValueCombiner::default(),
524 over_fetch_factor: 2.0,
525 }
526 }
527
528 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
529 self.heap_factor = heap_factor;
530 self
531 }
532
533 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
534 self.combiner = combiner;
535 self
536 }
537
538 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
539 self.over_fetch_factor = factor.max(1.0);
540 self
541 }
542
543 fn bmp_fallback_scorer<'a>(
545 &self,
546 reader: &'a SegmentReader,
547 limit: usize,
548 ) -> crate::Result<Box<dyn Scorer + 'a>> {
549 if let Some(bmp) = reader.bmp_index(self.field) {
550 let results = crate::query::bmp::execute_bmp(
551 bmp,
552 &[(self.dim_id, self.weight)],
553 limit,
554 self.heap_factor,
555 0,
556 )?;
557 let combined = crate::segment::combine_ordinal_results(
558 results.into_iter().map(|r| (r.doc_id, r.ordinal, r.score)),
559 self.combiner,
560 limit,
561 );
562 return Ok(Box::new(
563 crate::query::planner::VectorTopKResultScorer::new(combined, self.field.0),
564 ));
565 }
566 Ok(Box::new(crate::query::EmptyScorer))
567 }
568
569 fn make_scorer<'a>(
572 &self,
573 reader: &'a SegmentReader,
574 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
575 let si = match reader.sparse_index(self.field) {
576 Some(si) => si,
577 None => return Ok(None),
578 };
579 let (skip_start, skip_count, global_max, block_data_offset) =
580 match si.get_skip_range_full(self.dim_id) {
581 Some(v) => v,
582 None => return Ok(None),
583 };
584 let cursor = crate::query::TermCursor::sparse(
585 si,
586 self.weight,
587 skip_start,
588 skip_count,
589 global_max,
590 block_data_offset,
591 );
592 Ok(Some(SparseTermScorer {
593 cursor,
594 field_id: self.field.0,
595 }))
596 }
597}
598
impl Query for SparseTermQuery {
    // Async scorer: prefer a direct posting cursor over the sparse index;
    // if the field has no sparse index (or the dimension no postings),
    // fall back to BMP execution.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let query = self.clone();
        Box::pin(async move {
            let mut scorer = match query.make_scorer(reader)? {
                Some(s) => s,
                None => return query.bmp_fallback_scorer(reader, limit),
            };
            // Best-effort preload of the first block; errors are deliberately
            // swallowed here and will presumably resurface on the first
            // advance/seek — TODO confirm TermCursor's error behavior.
            scorer.cursor.ensure_block_loaded().await.ok();
            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
        })
    }

    // Synchronous twin of `scorer` for the "sync" feature.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let mut scorer = match self.make_scorer(reader)? {
            Some(s) => s,
            None => return self.bmp_fallback_scorer(reader, limit),
        };
        scorer.cursor.ensure_block_loaded_sync().ok();
        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
    }

    // Upper-bound document count for this dimension: number of skip entries
    // times 256 (presumably the per-block document capacity — TODO confirm
    // against the sparse index format).
    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
        let field = self.field;
        let dim_id = self.dim_id;
        Box::pin(async move {
            let si = match reader.sparse_index(field) {
                Some(si) => si,
                None => return Ok(0),
            };
            match si.get_skip_range_full(dim_id) {
                Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
                None => Ok(0),
            }
        })
    }

    // Exposes this term for multi-term planning. `max_superblocks: 0`
    // matches SparseVectorQuery's default.
    fn decompose(&self) -> crate::query::QueryDecomposition {
        crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
            field: self.field,
            dim_id: self.dim_id,
            weight: self.weight,
            heap_factor: self.heap_factor,
            combiner: self.combiner,
            over_fetch_factor: self.over_fetch_factor,
            max_superblocks: 0,
        }])
    }
}
653
/// Scorer that walks a single sparse dimension's posting list through a
/// `TermCursor` (created with the query-side weight in `make_scorer`).
struct SparseTermScorer<'a> {
    /// Cursor over this dimension's posting blocks.
    cursor: crate::query::TermCursor<'a>,
    /// Numeric field id, reported in `matched_positions`.
    field_id: u32,
}
662
663impl crate::query::docset::DocSet for SparseTermScorer<'_> {
664 fn doc(&self) -> DocId {
665 let d = self.cursor.doc();
666 if d == u32::MAX { TERMINATED } else { d }
667 }
668
669 fn advance(&mut self) -> DocId {
670 match self.cursor.advance_sync() {
671 Ok(d) if d == u32::MAX => TERMINATED,
672 Ok(d) => d,
673 Err(_) => TERMINATED,
674 }
675 }
676
677 fn seek(&mut self, target: DocId) -> DocId {
678 match self.cursor.seek_sync(target) {
679 Ok(d) if d == u32::MAX => TERMINATED,
680 Ok(d) => d,
681 Err(_) => TERMINATED,
682 }
683 }
684
685 fn size_hint(&self) -> u32 {
686 0
687 }
688}
689
690impl Scorer for SparseTermScorer<'_> {
691 fn score(&self) -> Score {
692 self.cursor.score()
693 }
694
695 fn matched_positions(&self) -> Option<MatchedPositions> {
696 let ordinal = self.cursor.ordinal();
697 let score = self.cursor.score();
698 if score == 0.0 {
699 return None;
700 }
701 Some(vec![(
702 self.field_id,
703 vec![ScoredPosition::new(ordinal as u32, score)],
704 )])
705 }
706}
707
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_sparse_vector_query_new() {
        let dims = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let q = SparseVectorQuery::new(Field(0), dims.clone());

        assert_eq!(q.field, Field(0));
        assert_eq!(q.vector, dims);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let q = SparseVectorQuery::from_indices_weights(
            Field(0),
            vec![1, 5, 10],
            vec![0.5, 0.3, 0.2],
        );

        let expected = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        assert_eq!(q.vector, expected);
    }
}