1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Query that scores documents on one field against a sparse vector of `(dim_id, weight)` pairs.
///
/// The dimensions actually scored are `vector` after optional pruning configured through
/// [`SparseVectorQuery::with_weight_threshold`], [`SparseVectorQuery::with_pruning`] and
/// [`SparseVectorQuery::with_max_query_dims`]; the result is cached in `pruned`.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field whose sparse index is searched; `field.0` is also printed by `Display`.
    pub field: Field,
    /// Full, unpruned query vector as `(dim_id, weight)` pairs.
    pub vector: Vec<(u32, f32)>,
    /// Combiner forwarded to the planner and to `combine_sparse_results`.
    /// `new` defaults it to `LogSumExp { temperature: 0.7 }`.
    /// NOTE(review): presumably merges several values per document into one score — confirm.
    pub combiner: MultiValueCombiner,
    /// Clamped to `[0.0, 1.0]` by `with_heap_factor` (default `1.0`) and forwarded to every
    /// per-dimension planner descriptor. `Display` shows it only when `< 1.0`.
    pub heap_factor: f32,
    /// When `> 0.0`, dimensions with `|weight| < weight_threshold` are dropped during pruning.
    pub weight_threshold: f32,
    /// Optional cap on the number of dimensions kept (those with the largest `|weight|`).
    pub max_query_dims: Option<usize>,
    /// Optional fraction in `[0.0, 1.0]` of the post-threshold dimensions to keep, by largest
    /// `|weight|` (rounded up, never fewer than one).
    pub pruning: Option<f32>,
    /// Floored at `1.0` by `with_over_fetch_factor` (default `2.0`) and forwarded to the planner.
    /// NOTE(review): presumably a multiplier on candidates fetched before combining — confirm.
    pub over_fetch_factor: f32,
    /// Cached pruned vector; `None` until one of the pruning builders runs.
    /// It is NOT refreshed when the public fields above are assigned directly, so use the
    /// `with_*` builders to change pruning settings.
    pruned: Option<Vec<(u32, f32)>>,
}
38
39impl std::fmt::Display for SparseVectorQuery {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let dims = self.pruned_dims();
42 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
43 if self.heap_factor < 1.0 {
44 write!(f, ", heap={}", self.heap_factor)?;
45 }
46 if self.vector.len() != dims.len() {
47 write!(f, ", orig={}", self.vector.len())?;
48 }
49 write!(f, ")")
50 }
51}
52
53impl SparseVectorQuery {
54 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
61 Self {
62 field,
63 vector,
64 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
65 heap_factor: 1.0,
66 weight_threshold: 0.0,
67 max_query_dims: None,
68 pruning: None,
69 over_fetch_factor: 2.0,
70 pruned: None,
71 }
72 }
73
74 pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
76 self.pruned.as_deref().unwrap_or(&self.vector)
77 }
78
79 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
81 self.combiner = combiner;
82 self
83 }
84
85 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
90 self.over_fetch_factor = factor.max(1.0);
91 self
92 }
93
94 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
101 self.heap_factor = heap_factor.clamp(0.0, 1.0);
102 self
103 }
104
105 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
108 self.weight_threshold = threshold;
109 self.pruned = Some(self.compute_pruned_vector());
110 self
111 }
112
113 pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
115 self.max_query_dims = Some(max_dims);
116 self.pruned = Some(self.compute_pruned_vector());
117 self
118 }
119
120 pub fn with_pruning(mut self, fraction: f32) -> Self {
123 self.pruning = Some(fraction.clamp(0.0, 1.0));
124 self.pruned = Some(self.compute_pruned_vector());
125 self
126 }
127
128 fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
130 let original_len = self.vector.len();
131
132 let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
134 self.vector
135 .iter()
136 .copied()
137 .filter(|(_, w)| w.abs() >= self.weight_threshold)
138 .collect()
139 } else {
140 self.vector.clone()
141 };
142 let after_threshold = v.len();
143
144 let mut sorted_by_weight = false;
146 if let Some(fraction) = self.pruning
147 && fraction < 1.0
148 && v.len() > 1
149 {
150 v.sort_unstable_by(|a, b| {
151 b.1.abs()
152 .partial_cmp(&a.1.abs())
153 .unwrap_or(std::cmp::Ordering::Equal)
154 });
155 sorted_by_weight = true;
156 let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
157 v.truncate(keep);
158 }
159 let after_pruning = v.len();
160
161 if let Some(max_dims) = self.max_query_dims
163 && v.len() > max_dims
164 {
165 if !sorted_by_weight {
166 v.sort_unstable_by(|a, b| {
167 b.1.abs()
168 .partial_cmp(&a.1.abs())
169 .unwrap_or(std::cmp::Ordering::Equal)
170 });
171 }
172 v.truncate(max_dims);
173 }
174
175 if v.len() < original_len {
176 let src: Vec<_> = self
177 .vector
178 .iter()
179 .map(|(d, w)| format!("({},{:.4})", d, w))
180 .collect();
181 let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
182 log::debug!(
183 "[sparse query] field={}: pruned {}->{} dims \
184 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
185 source=[{}], pruned=[{}]",
186 self.field.0,
187 original_len,
188 v.len(),
189 original_len,
190 after_threshold,
191 after_threshold,
192 after_pruning,
193 after_pruning,
194 v.len(),
195 src.join(", "),
196 pruned_fmt.join(", "),
197 );
198 }
199
200 v
201 }
202
203 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
205 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
206 Self::new(field, vector)
207 }
208
209 #[cfg(feature = "native")]
221 pub fn from_text(
222 field: Field,
223 text: &str,
224 tokenizer_name: &str,
225 weighting: crate::structures::QueryWeighting,
226 sparse_index: Option<&crate::segment::SparseIndex>,
227 ) -> crate::Result<Self> {
228 use crate::structures::QueryWeighting;
229 use crate::tokenizer::tokenizer_cache;
230
231 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
232 let token_ids = tokenizer.tokenize_unique(text)?;
233
234 let weights: Vec<f32> = match weighting {
235 QueryWeighting::One => vec![1.0f32; token_ids.len()],
236 QueryWeighting::Idf => {
237 if let Some(index) = sparse_index {
238 index.idf_weights(&token_ids)
239 } else {
240 vec![1.0f32; token_ids.len()]
241 }
242 }
243 QueryWeighting::IdfFile => {
244 use crate::tokenizer::idf_weights_cache;
245 if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
246 token_ids.iter().map(|&id| idf.get(id)).collect()
247 } else {
248 vec![1.0f32; token_ids.len()]
249 }
250 }
251 };
252
253 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
254 Ok(Self::new(field, vector))
255 }
256
257 #[cfg(feature = "native")]
269 pub fn from_text_with_stats(
270 field: Field,
271 text: &str,
272 tokenizer: &crate::tokenizer::HfTokenizer,
273 weighting: crate::structures::QueryWeighting,
274 global_stats: Option<&crate::query::GlobalStats>,
275 ) -> crate::Result<Self> {
276 use crate::structures::QueryWeighting;
277
278 let token_ids = tokenizer.tokenize_unique(text)?;
279
280 let weights: Vec<f32> = match weighting {
281 QueryWeighting::One => vec![1.0f32; token_ids.len()],
282 QueryWeighting::Idf => {
283 if let Some(stats) = global_stats {
284 stats
286 .sparse_idf_weights(field, &token_ids)
287 .into_iter()
288 .map(|w| w.max(0.0))
289 .collect()
290 } else {
291 vec![1.0f32; token_ids.len()]
292 }
293 }
294 QueryWeighting::IdfFile => {
295 vec![1.0f32; token_ids.len()]
298 }
299 };
300
301 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
302 Ok(Self::new(field, vector))
303 }
304
305 #[cfg(feature = "native")]
317 pub fn from_text_with_tokenizer_bytes(
318 field: Field,
319 text: &str,
320 tokenizer_bytes: &[u8],
321 weighting: crate::structures::QueryWeighting,
322 global_stats: Option<&crate::query::GlobalStats>,
323 ) -> crate::Result<Self> {
324 use crate::structures::QueryWeighting;
325 use crate::tokenizer::HfTokenizer;
326
327 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
328 let token_ids = tokenizer.tokenize_unique(text)?;
329
330 let weights: Vec<f32> = match weighting {
331 QueryWeighting::One => vec![1.0f32; token_ids.len()],
332 QueryWeighting::Idf => {
333 if let Some(stats) = global_stats {
334 stats
336 .sparse_idf_weights(field, &token_ids)
337 .into_iter()
338 .map(|w| w.max(0.0))
339 .collect()
340 } else {
341 vec![1.0f32; token_ids.len()]
342 }
343 }
344 QueryWeighting::IdfFile => {
345 vec![1.0f32; token_ids.len()]
348 }
349 };
350
351 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
352 Ok(Self::new(field, vector))
353 }
354}
355
356impl SparseVectorQuery {
357 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
359 self.pruned_dims()
360 .iter()
361 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
362 field: self.field,
363 dim_id,
364 weight,
365 heap_factor: self.heap_factor,
366 combiner: self.combiner,
367 over_fetch_factor: self.over_fetch_factor,
368 })
369 .collect()
370 }
371}
372
373impl Query for SparseVectorQuery {
374 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
375 let infos = self.sparse_infos();
376 let field = self.field;
377 let heap_factor = self.heap_factor;
378 let combiner = self.combiner;
379 let over_fetch_factor = self.over_fetch_factor;
380
381 Box::pin(async move {
382 if infos.is_empty() {
383 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
384 }
385
386 if infos.len() == 1 {
388 let info = &infos[0];
389 let term = SparseTermQuery::new(field, info.dim_id, info.weight)
390 .with_heap_factor(heap_factor)
391 .with_combiner(combiner)
392 .with_over_fetch_factor(over_fetch_factor);
393 return term.scorer(reader, limit).await;
394 }
395
396 if let Some((executor, info)) =
398 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
399 {
400 let raw = executor.execute().await?;
401 return Ok(crate::query::planner::combine_sparse_results(
402 raw,
403 info.combiner,
404 info.field,
405 limit,
406 ));
407 }
408
409 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
410 })
411 }
412
413 #[cfg(feature = "sync")]
414 fn scorer_sync<'a>(
415 &self,
416 reader: &'a SegmentReader,
417 limit: usize,
418 ) -> crate::Result<Box<dyn Scorer + 'a>> {
419 let infos = self.sparse_infos();
420 if infos.is_empty() {
421 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
422 }
423
424 if infos.len() == 1 {
426 let info = &infos[0];
427 let term = SparseTermQuery::new(self.field, info.dim_id, info.weight)
428 .with_heap_factor(self.heap_factor)
429 .with_combiner(self.combiner)
430 .with_over_fetch_factor(self.over_fetch_factor);
431 return term.scorer_sync(reader, limit);
432 }
433
434 if let Some((executor, info)) =
436 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
437 {
438 let raw = executor.execute_sync()?;
439 return Ok(crate::query::planner::combine_sparse_results(
440 raw,
441 info.combiner,
442 info.field,
443 limit,
444 ));
445 }
446
447 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>)
448 }
449
450 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
451 Box::pin(async move { Ok(u32::MAX) })
452 }
453
454 fn decompose(&self) -> crate::query::QueryDecomposition {
455 let infos = self.sparse_infos();
456 if infos.is_empty() {
457 crate::query::QueryDecomposition::Opaque
458 } else {
459 crate::query::QueryDecomposition::SparseTerms(infos)
460 }
461 }
462}
463
/// Query that scores documents on a single sparse dimension `dim_id` of `field`.
///
/// `SparseVectorQuery` also builds one of these when only one of its dimensions remains.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Field whose sparse index is looked up.
    pub field: Field,
    /// Dimension to score (for example a token id when the vector query was built from text).
    pub dim_id: u32,
    /// Query-side weight handed to the posting cursor (`TermCursor::sparse`).
    pub weight: f32,
    /// Forwarded to the planner through `decompose`; `new` sets it to `1.0`.
    pub heap_factor: f32,
    /// Forwarded to the planner through `decompose`; `new` uses `MultiValueCombiner::default()`.
    pub combiner: MultiValueCombiner,
    /// Forwarded to the planner through `decompose`; `new` sets `2.0`, and
    /// `with_over_fetch_factor` floors it at `1.0`.
    pub over_fetch_factor: f32,
}
483
484impl std::fmt::Display for SparseTermQuery {
485 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486 write!(
487 f,
488 "SparseTerm({}, dim={}, w={:.3})",
489 self.field.0, self.dim_id, self.weight
490 )
491 }
492}
493
494impl SparseTermQuery {
495 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
496 Self {
497 field,
498 dim_id,
499 weight,
500 heap_factor: 1.0,
501 combiner: MultiValueCombiner::default(),
502 over_fetch_factor: 2.0,
503 }
504 }
505
506 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
507 self.heap_factor = heap_factor;
508 self
509 }
510
511 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
512 self.combiner = combiner;
513 self
514 }
515
516 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
517 self.over_fetch_factor = factor.max(1.0);
518 self
519 }
520
521 fn make_scorer<'a>(
524 &self,
525 reader: &'a SegmentReader,
526 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
527 let si = match reader.sparse_index(self.field) {
528 Some(si) => si,
529 None => return Ok(None),
530 };
531 let (skip_start, skip_count, global_max, block_data_offset) =
532 match si.get_skip_range_full(self.dim_id) {
533 Some(v) => v,
534 None => return Ok(None),
535 };
536 let cursor = crate::query::TermCursor::sparse(
537 si,
538 self.weight,
539 skip_start,
540 skip_count,
541 global_max,
542 block_data_offset,
543 );
544 Ok(Some(SparseTermScorer {
545 cursor,
546 field_id: self.field.0,
547 }))
548 }
549}
550
551impl Query for SparseTermQuery {
552 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
553 let query = self.clone();
554 Box::pin(async move {
555 let mut scorer = match query.make_scorer(reader)? {
556 Some(s) => s,
557 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
558 };
559 scorer.cursor.ensure_block_loaded().await.ok();
560 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
561 })
562 }
563
564 #[cfg(feature = "sync")]
565 fn scorer_sync<'a>(
566 &self,
567 reader: &'a SegmentReader,
568 _limit: usize,
569 ) -> crate::Result<Box<dyn Scorer + 'a>> {
570 let mut scorer = match self.make_scorer(reader)? {
571 Some(s) => s,
572 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
573 };
574 scorer.cursor.ensure_block_loaded_sync().ok();
575 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
576 }
577
578 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
579 let field = self.field;
580 let dim_id = self.dim_id;
581 Box::pin(async move {
582 let si = match reader.sparse_index(field) {
583 Some(si) => si,
584 None => return Ok(0),
585 };
586 match si.get_skip_range_full(dim_id) {
587 Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
588 None => Ok(0),
589 }
590 })
591 }
592
593 fn decompose(&self) -> crate::query::QueryDecomposition {
594 crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
595 field: self.field,
596 dim_id: self.dim_id,
597 weight: self.weight,
598 heap_factor: self.heap_factor,
599 combiner: self.combiner,
600 over_fetch_factor: self.over_fetch_factor,
601 }])
602 }
603}
604
605struct SparseTermScorer<'a> {
610 cursor: crate::query::TermCursor<'a>,
611 field_id: u32,
612}
613
614impl crate::query::docset::DocSet for SparseTermScorer<'_> {
615 fn doc(&self) -> DocId {
616 let d = self.cursor.doc();
617 if d == u32::MAX { TERMINATED } else { d }
618 }
619
620 fn advance(&mut self) -> DocId {
621 match self.cursor.advance_sync() {
622 Ok(d) if d == u32::MAX => TERMINATED,
623 Ok(d) => d,
624 Err(_) => TERMINATED,
625 }
626 }
627
628 fn seek(&mut self, target: DocId) -> DocId {
629 match self.cursor.seek_sync(target) {
630 Ok(d) if d == u32::MAX => TERMINATED,
631 Ok(d) => d,
632 Err(_) => TERMINATED,
633 }
634 }
635
636 fn size_hint(&self) -> u32 {
637 0
638 }
639}
640
641impl Scorer for SparseTermScorer<'_> {
642 fn score(&self) -> Score {
643 self.cursor.score()
644 }
645
646 fn matched_positions(&self) -> Option<MatchedPositions> {
647 let ordinal = self.cursor.ordinal();
648 let score = self.cursor.score();
649 if score == 0.0 {
650 return None;
651 }
652 Some(vec![(
653 self.field_id,
654 vec![ScoredPosition::new(ordinal as u32, score)],
655 )])
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// Fixture with distinct magnitudes so magnitude-sorted output is deterministic.
    fn sample() -> Vec<(u32, f32)> {
        vec![(1, 0.1), (2, -0.9), (3, 0.5)]
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_from_indices_weights_truncates_to_shorter_input() {
        let query = SparseVectorQuery::from_indices_weights(Field(0), vec![1, 2, 3], vec![0.5]);
        assert_eq!(query.vector, vec![(1, 0.5)]);
    }

    #[test]
    fn test_no_pruning_uses_full_vector() {
        let query = SparseVectorQuery::new(Field(0), sample());
        assert_eq!(query.pruned_dims().to_vec(), sample());
        assert_eq!(query.to_string(), "Sparse(0, dims=3)");
    }

    #[test]
    fn test_weight_threshold_preserves_order() {
        let query = SparseVectorQuery::new(Field(0), sample()).with_weight_threshold(0.3);
        assert_eq!(query.pruned_dims().to_vec(), vec![(2, -0.9), (3, 0.5)]);
        assert_eq!(query.to_string(), "Sparse(0, dims=2, orig=3)");
    }

    #[test]
    fn test_max_query_dims_keeps_largest_magnitudes() {
        let query = SparseVectorQuery::new(Field(0), sample()).with_max_query_dims(2);
        assert_eq!(query.pruned_dims().to_vec(), vec![(2, -0.9), (3, 0.5)]);
    }

    #[test]
    fn test_pruning_fraction_rounds_up_and_keeps_at_least_one() {
        let half = SparseVectorQuery::new(Field(0), sample()).with_pruning(0.5);
        assert_eq!(half.pruned_dims().to_vec(), vec![(2, -0.9), (3, 0.5)]);

        let nothing = SparseVectorQuery::new(Field(0), sample()).with_pruning(0.0);
        assert_eq!(nothing.pruned_dims().to_vec(), vec![(2, -0.9)]);
    }

    #[test]
    fn test_pruning_builders_are_order_independent() {
        let a = SparseVectorQuery::new(Field(0), sample())
            .with_weight_threshold(0.3)
            .with_max_query_dims(1);
        let b = SparseVectorQuery::new(Field(0), sample())
            .with_max_query_dims(1)
            .with_weight_threshold(0.3);
        assert_eq!(a.pruned_dims().to_vec(), vec![(2, -0.9)]);
        assert_eq!(b.pruned_dims(), a.pruned_dims());
    }

    #[test]
    fn test_factor_builders_clamp() {
        let query = SparseVectorQuery::new(Field(0), sample())
            .with_heap_factor(1.5)
            .with_over_fetch_factor(0.5);
        assert_eq!(query.heap_factor, 1.0);
        assert_eq!(query.over_fetch_factor, 1.0);

        let query = query.with_heap_factor(0.5);
        assert_eq!(query.to_string(), "Sparse(0, dims=3, heap=0.5)");
    }

    #[test]
    fn test_sparse_term_query_display() {
        let term = SparseTermQuery::new(Field(2), 7, 0.25);
        assert_eq!(term.to_string(), "SparseTerm(2, dim=7, w=0.250)");
    }
}