1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Query over a sparse-vector field, expressed as `(dimension id, weight)` pairs.
///
/// The pruned form of the vector is cached in a private field. It is refreshed
/// by `new` and by the `with_weight_threshold`, `with_max_query_dims` and
/// `with_pruning` builders; assigning the public tuning fields directly does
/// not refresh that cache.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field this query is issued against.
    pub field: Field,
    /// Full query vector as `(dimension id, weight)` pairs, before any pruning.
    pub vector: Vec<(u32, f32)>,
    /// Combiner copied into every `SparseTermQueryInfo` handed to the planner.
    /// NOTE(review): presumably merges the scores of multi-valued documents — confirm.
    pub combiner: MultiValueCombiner,
    /// Copied into every `SparseTermQueryInfo`; the builder clamps it to `[0.0, 1.0]`.
    /// NOTE(review): presumably values below 1.0 trade accuracy for speed — confirm.
    pub heap_factor: f32,
    /// When greater than zero, dimensions whose `|weight|` is below this value are dropped.
    pub weight_threshold: f32,
    /// Upper bound on the number of dimensions kept (largest `|weight|` first).
    /// `None` disables the cap.
    pub max_query_dims: Option<usize>,
    /// Fraction of dimensions to keep (largest `|weight|` first).
    /// `None` disables fractional pruning.
    pub pruning: Option<f32>,
    /// Copied into every `SparseTermQueryInfo`; the builder keeps it at `>= 1.0`.
    pub over_fetch_factor: f32,
    /// Cached output of `compute_pruned_vector`; `pruned_dims` falls back to
    /// `vector` when this is `None`.
    pruned: Option<Vec<(u32, f32)>>,
}
38
39impl std::fmt::Display for SparseVectorQuery {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let dims = self.pruned_dims();
42 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
43 if self.heap_factor < 1.0 {
44 write!(f, ", heap={}", self.heap_factor)?;
45 }
46 if self.vector.len() != dims.len() {
47 write!(f, ", orig={}", self.vector.len())?;
48 }
49 write!(f, ")")
50 }
51}
52
53impl SparseVectorQuery {
54 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
61 let mut q = Self {
62 field,
63 vector,
64 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
65 heap_factor: 1.0,
66 weight_threshold: 0.0,
67 max_query_dims: Some(crate::query::MAX_QUERY_TERMS),
68 pruning: None,
69 over_fetch_factor: 2.0,
70 pruned: None,
71 };
72 q.pruned = Some(q.compute_pruned_vector());
73 q
74 }
75
76 pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
78 self.pruned.as_deref().unwrap_or(&self.vector)
79 }
80
81 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
83 self.combiner = combiner;
84 self
85 }
86
87 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
92 self.over_fetch_factor = factor.max(1.0);
93 self
94 }
95
96 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
103 self.heap_factor = heap_factor.clamp(0.0, 1.0);
104 self
105 }
106
107 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
110 self.weight_threshold = threshold;
111 self.pruned = Some(self.compute_pruned_vector());
112 self
113 }
114
115 pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
117 self.max_query_dims = Some(max_dims);
118 self.pruned = Some(self.compute_pruned_vector());
119 self
120 }
121
122 pub fn with_pruning(mut self, fraction: f32) -> Self {
125 self.pruning = Some(fraction.clamp(0.0, 1.0));
126 self.pruned = Some(self.compute_pruned_vector());
127 self
128 }
129
130 fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
132 let original_len = self.vector.len();
133
134 let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
136 self.vector
137 .iter()
138 .copied()
139 .filter(|(_, w)| w.abs() >= self.weight_threshold)
140 .collect()
141 } else {
142 self.vector.clone()
143 };
144 let after_threshold = v.len();
145
146 let mut sorted_by_weight = false;
148 if let Some(fraction) = self.pruning
149 && fraction < 1.0
150 && v.len() > 1
151 {
152 v.sort_unstable_by(|a, b| {
153 b.1.abs()
154 .partial_cmp(&a.1.abs())
155 .unwrap_or(std::cmp::Ordering::Equal)
156 });
157 sorted_by_weight = true;
158 let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
159 v.truncate(keep);
160 }
161 let after_pruning = v.len();
162
163 if let Some(max_dims) = self.max_query_dims
165 && v.len() > max_dims
166 {
167 if !sorted_by_weight {
168 v.sort_unstable_by(|a, b| {
169 b.1.abs()
170 .partial_cmp(&a.1.abs())
171 .unwrap_or(std::cmp::Ordering::Equal)
172 });
173 }
174 v.truncate(max_dims);
175 }
176
177 if v.len() < original_len && log::log_enabled!(log::Level::Debug) {
178 let src: Vec<_> = self
179 .vector
180 .iter()
181 .map(|(d, w)| format!("({},{:.4})", d, w))
182 .collect();
183 let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
184 log::debug!(
185 "[sparse query] field={}: pruned {}->{} dims \
186 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
187 source=[{}], pruned=[{}]",
188 self.field.0,
189 original_len,
190 v.len(),
191 original_len,
192 after_threshold,
193 after_threshold,
194 after_pruning,
195 after_pruning,
196 v.len(),
197 src.join(", "),
198 pruned_fmt.join(", "),
199 );
200 }
201
202 v
203 }
204
205 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
207 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
208 Self::new(field, vector)
209 }
210
/// Builds a query by tokenizing `text` with the cached tokenizer named
/// `tokenizer_name` and weighting the unique token ids per `weighting`:
///
/// * `One` — every token gets weight 1.0.
/// * `Idf` — weights come from `sparse_index.idf_weights`; without an index
///   every token gets weight 1.0.
/// * `IdfFile` — weights come from the cached IDF table for the tokenizer;
///   if none can be loaded every token gets weight 1.0.
///
/// # Errors
/// Propagates failures from loading the tokenizer or tokenizing `text`.
#[cfg(feature = "native")]
pub fn from_text(
    field: Field,
    text: &str,
    tokenizer_name: &str,
    weighting: crate::structures::QueryWeighting,
    sparse_index: Option<&crate::segment::SparseIndex>,
) -> crate::Result<Self> {
    use crate::structures::QueryWeighting;
    use crate::tokenizer::{idf_weights_cache, tokenizer_cache};

    let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
    let token_ids = tokenizer.tokenize_unique(text)?;

    // Fallback used whenever no IDF source is available.
    let token_count = token_ids.len();
    let uniform = || vec![1.0f32; token_count];

    let weights: Vec<f32> = match weighting {
        QueryWeighting::One => uniform(),
        QueryWeighting::Idf => match sparse_index {
            Some(index) => index.idf_weights(&token_ids),
            None => uniform(),
        },
        QueryWeighting::IdfFile => {
            if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name, None) {
                token_ids.iter().map(|&id| idf.get(id)).collect()
            } else {
                uniform()
            }
        }
    };

    let pairs: Vec<(u32, f32)> = std::iter::zip(token_ids, weights).collect();
    Ok(Self::new(field, pairs))
}
258
/// Builds a query by tokenizing `text` with the supplied `tokenizer`.
///
/// With `Idf` weighting and `global_stats` present, weights come from
/// `sparse_idf_weights` with negative values raised to 0.0. In every other
/// case (`One`, `IdfFile`, or `Idf` without stats) each token gets weight 1.0.
///
/// # Errors
/// Propagates tokenization failures.
#[cfg(feature = "native")]
pub fn from_text_with_stats(
    field: Field,
    text: &str,
    tokenizer: &crate::tokenizer::HfTokenizer,
    weighting: crate::structures::QueryWeighting,
    global_stats: Option<&crate::query::GlobalStats>,
) -> crate::Result<Self> {
    use crate::structures::QueryWeighting;

    let token_ids = tokenizer.tokenize_unique(text)?;

    let token_count = token_ids.len();
    let uniform = || vec![1.0f32; token_count];

    let weights: Vec<f32> = match weighting {
        QueryWeighting::One | QueryWeighting::IdfFile => uniform(),
        QueryWeighting::Idf => match global_stats {
            Some(stats) => stats
                .sparse_idf_weights(field, &token_ids)
                .into_iter()
                .map(|w| w.max(0.0))
                .collect(),
            None => uniform(),
        },
    };

    let pairs: Vec<(u32, f32)> = std::iter::zip(token_ids, weights).collect();
    Ok(Self::new(field, pairs))
}
306
/// Builds a query by tokenizing `text` with a tokenizer deserialized from
/// `tokenizer_bytes`.
///
/// Weighting is delegated to [`Self::from_text_with_stats`] so the two
/// constructors cannot drift apart: `Idf` with `global_stats` uses the
/// stats-derived weights (negatives raised to 0.0); every other combination
/// assigns weight 1.0 to each token.
///
/// # Errors
/// Propagates failures from constructing the tokenizer or tokenizing `text`.
#[cfg(feature = "native")]
pub fn from_text_with_tokenizer_bytes(
    field: Field,
    text: &str,
    tokenizer_bytes: &[u8],
    weighting: crate::structures::QueryWeighting,
    global_stats: Option<&crate::query::GlobalStats>,
) -> crate::Result<Self> {
    use crate::tokenizer::HfTokenizer;

    let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
    Self::from_text_with_stats(field, text, &tokenizer, weighting, global_stats)
}
356}
357
358impl SparseVectorQuery {
359 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
361 self.pruned_dims()
362 .iter()
363 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
364 field: self.field,
365 dim_id,
366 weight,
367 heap_factor: self.heap_factor,
368 combiner: self.combiner,
369 over_fetch_factor: self.over_fetch_factor,
370 })
371 .collect()
372 }
373}
374
375impl Query for SparseVectorQuery {
376 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
377 let infos = self.sparse_infos();
378
379 Box::pin(async move {
380 if infos.is_empty() {
381 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
382 }
383
384 if let Some((raw, info)) =
386 crate::query::planner::build_sparse_bmp_results(&infos, reader, limit)
387 {
388 return Ok(crate::query::planner::combine_sparse_results(
389 raw,
390 info.combiner,
391 info.field,
392 limit,
393 ));
394 }
395
396 if let Some((executor, info)) =
398 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
399 {
400 let raw = executor.execute().await?;
401 return Ok(crate::query::planner::combine_sparse_results(
402 raw,
403 info.combiner,
404 info.field,
405 limit,
406 ));
407 }
408
409 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
410 })
411 }
412
/// Blocking counterpart of `scorer`: same strategy order (empty check, BMP,
/// MaxScore executor, `EmptyScorer` fallback) using `execute_sync`.
#[cfg(feature = "sync")]
fn scorer_sync<'a>(
    &self,
    reader: &'a SegmentReader,
    limit: usize,
) -> crate::Result<Box<dyn Scorer + 'a>> {
    use crate::query::planner;

    let infos = self.sparse_infos();
    if infos.is_empty() {
        return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
    }

    // First choice: BMP results, when the planner can produce them.
    if let Some((raw, info)) = planner::build_sparse_bmp_results(&infos, reader, limit) {
        let combined = planner::combine_sparse_results(raw, info.combiner, info.field, limit);
        return Ok(combined);
    }

    // Second choice: the MaxScore executor.
    let Some((executor, info)) =
        planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
    else {
        return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
    };
    let raw = executor.execute_sync()?;
    Ok(planner::combine_sparse_results(
        raw,
        info.combiner,
        info.field,
        limit,
    ))
}
451
452 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
453 Box::pin(async move { Ok(u32::MAX) })
454 }
455
456 fn decompose(&self) -> crate::query::QueryDecomposition {
457 let infos = self.sparse_infos();
458 if infos.is_empty() {
459 crate::query::QueryDecomposition::Opaque
460 } else {
461 crate::query::QueryDecomposition::SparseTerms(infos)
462 }
463 }
464}
465
/// Query for a single dimension of a sparse-vector field.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Field this query is issued against.
    pub field: Field,
    /// Dimension id looked up in the field's sparse index.
    pub dim_id: u32,
    /// Query-side weight for the dimension.
    pub weight: f32,
    /// Passed to `execute_bmp` on the BMP fallback path and exposed through `decompose`.
    /// NOTE(review): presumably values below 1.0 trade accuracy for speed — confirm.
    pub heap_factor: f32,
    /// Combiner applied to per-ordinal results on the BMP fallback path.
    pub combiner: MultiValueCombiner,
    /// Exposed through `decompose`; the builder keeps it at `>= 1.0`.
    pub over_fetch_factor: f32,
}
485
486impl std::fmt::Display for SparseTermQuery {
487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
488 write!(
489 f,
490 "SparseTerm({}, dim={}, w={:.3})",
491 self.field.0, self.dim_id, self.weight
492 )
493 }
494}
495
496impl SparseTermQuery {
497 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
498 Self {
499 field,
500 dim_id,
501 weight,
502 heap_factor: 1.0,
503 combiner: MultiValueCombiner::default(),
504 over_fetch_factor: 2.0,
505 }
506 }
507
508 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
509 self.heap_factor = heap_factor;
510 self
511 }
512
513 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
514 self.combiner = combiner;
515 self
516 }
517
518 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
519 self.over_fetch_factor = factor.max(1.0);
520 self
521 }
522
523 fn bmp_fallback_scorer<'a>(
525 &self,
526 reader: &'a SegmentReader,
527 limit: usize,
528 ) -> crate::Result<Box<dyn Scorer + 'a>> {
529 if let Some(bmp) = reader.bmp_index(self.field) {
530 let results = crate::query::bmp::execute_bmp(
531 bmp,
532 &[(self.dim_id, self.weight)],
533 limit,
534 self.heap_factor,
535 0,
536 )?;
537 let combined = crate::segment::combine_ordinal_results(
538 results.into_iter().map(|r| (r.doc_id, r.ordinal, r.score)),
539 self.combiner,
540 limit,
541 );
542 return Ok(Box::new(
543 crate::query::planner::VectorTopKResultScorer::new(combined, self.field.0),
544 ));
545 }
546 Ok(Box::new(crate::query::EmptyScorer))
547 }
548
549 fn make_scorer<'a>(
552 &self,
553 reader: &'a SegmentReader,
554 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
555 let si = match reader.sparse_index(self.field) {
556 Some(si) => si,
557 None => return Ok(None),
558 };
559 let (skip_start, skip_count, global_max, block_data_offset) =
560 match si.get_skip_range_full(self.dim_id) {
561 Some(v) => v,
562 None => return Ok(None),
563 };
564 let cursor = crate::query::TermCursor::sparse(
565 si,
566 self.weight,
567 skip_start,
568 skip_count,
569 global_max,
570 block_data_offset,
571 );
572 Ok(Some(SparseTermScorer {
573 cursor,
574 field_id: self.field.0,
575 }))
576 }
577}
578
579impl Query for SparseTermQuery {
580 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
581 let query = self.clone();
582 Box::pin(async move {
583 let mut scorer = match query.make_scorer(reader)? {
584 Some(s) => s,
585 None => return query.bmp_fallback_scorer(reader, limit),
586 };
587 scorer.cursor.ensure_block_loaded().await.ok();
588 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
589 })
590 }
591
/// Blocking counterpart of `scorer`: a cursor-backed scorer when the sparse
/// index has this dimension, otherwise the BMP fallback.
#[cfg(feature = "sync")]
fn scorer_sync<'a>(
    &self,
    reader: &'a SegmentReader,
    limit: usize,
) -> crate::Result<Box<dyn Scorer + 'a>> {
    let Some(mut scorer) = self.make_scorer(reader)? else {
        return self.bmp_fallback_scorer(reader, limit);
    };
    // A failure to load the first block is deliberately ignored here.
    // NOTE(review): confirm the cursor then reports itself as exhausted.
    let _ = scorer.cursor.ensure_block_loaded_sync();
    Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
}
605
606 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
607 let field = self.field;
608 let dim_id = self.dim_id;
609 Box::pin(async move {
610 let si = match reader.sparse_index(field) {
611 Some(si) => si,
612 None => return Ok(0),
613 };
614 match si.get_skip_range_full(dim_id) {
615 Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
616 None => Ok(0),
617 }
618 })
619 }
620
621 fn decompose(&self) -> crate::query::QueryDecomposition {
622 crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
623 field: self.field,
624 dim_id: self.dim_id,
625 weight: self.weight,
626 heap_factor: self.heap_factor,
627 combiner: self.combiner,
628 over_fetch_factor: self.over_fetch_factor,
629 }])
630 }
631}
632
/// Scorer for a single sparse dimension, driven by a `TermCursor`.
struct SparseTermScorer<'a> {
    /// Cursor over the dimension's postings; supplies the doc id, score and ordinal.
    cursor: crate::query::TermCursor<'a>,
    /// Raw id of the queried field, reported in `matched_positions`.
    field_id: u32,
}
641
642impl crate::query::docset::DocSet for SparseTermScorer<'_> {
643 fn doc(&self) -> DocId {
644 let d = self.cursor.doc();
645 if d == u32::MAX { TERMINATED } else { d }
646 }
647
648 fn advance(&mut self) -> DocId {
649 match self.cursor.advance_sync() {
650 Ok(d) if d == u32::MAX => TERMINATED,
651 Ok(d) => d,
652 Err(_) => TERMINATED,
653 }
654 }
655
656 fn seek(&mut self, target: DocId) -> DocId {
657 match self.cursor.seek_sync(target) {
658 Ok(d) if d == u32::MAX => TERMINATED,
659 Ok(d) => d,
660 Err(_) => TERMINATED,
661 }
662 }
663
664 fn size_hint(&self) -> u32 {
665 0
666 }
667}
668
669impl Scorer for SparseTermScorer<'_> {
670 fn score(&self) -> Score {
671 self.cursor.score()
672 }
673
674 fn matched_positions(&self) -> Option<MatchedPositions> {
675 let ordinal = self.cursor.ordinal();
676 let score = self.cursor.score();
677 if score == 0.0 {
678 return None;
679 }
680 Some(vec![(
681 self.field_id,
682 vec![ScoredPosition::new(ordinal as u32, score)],
683 )])
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// `new` keeps the field and the unpruned vector exactly as given.
    #[test]
    fn test_sparse_vector_query_new() {
        let expected = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let built = SparseVectorQuery::new(Field(0), expected.clone());

        assert_eq!(built.field, Field(0));
        assert_eq!(built.vector, expected);
    }

    /// Parallel index/weight lists are paired positionally.
    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let indices = vec![1, 5, 10];
        let weights = vec![0.5, 0.3, 0.2];
        let built = SparseVectorQuery::from_indices_weights(Field(0), indices, weights);

        assert_eq!(built.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}