1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Query against a sparse-vector field, expressed as `(dimension id, weight)` pairs.
///
/// The pruning builders (`with_weight_threshold`, `with_max_query_dims`,
/// `with_pruning`) recompute the private `pruned` cache. The pruning settings
/// are also public fields; assigning to them directly leaves the cache as it was.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field the query is issued against; copied into every per-dimension term info.
    pub field: Field,
    /// Query vector exactly as supplied: `(dimension id, weight)` pairs.
    pub vector: Vec<(u32, f32)>,
    /// Combiner copied into every per-dimension term info.
    pub combiner: MultiValueCombiner,
    /// Heap factor copied into every per-dimension term info. `new` starts at
    /// `1.0` and `with_heap_factor` clamps to `[0.0, 1.0]`.
    /// NOTE(review): presumably `1.0` means exact search and lower values trade
    /// accuracy for speed — confirm against the executor.
    pub heap_factor: f32,
    /// Minimum absolute weight a dimension needs to survive pruning. Only
    /// applied when greater than `0.0`.
    pub weight_threshold: f32,
    /// Upper bound on the number of dimensions kept (largest absolute weights
    /// win). `None` disables the cap.
    pub max_query_dims: Option<usize>,
    /// Fraction of dimensions to keep, by descending absolute weight. `None`
    /// or a value of `1.0` or more disables this stage.
    pub pruning: Option<f32>,
    /// Over-fetch factor copied into every per-dimension term info. `new`
    /// starts at `2.0` and `with_over_fetch_factor` enforces a minimum of `1.0`.
    pub over_fetch_factor: f32,
    /// Cached output of the pruning pipeline. `None` until one of the pruning
    /// builders runs; while `None`, the full `vector` is used.
    pruned: Option<Vec<(u32, f32)>>,
}
38
39impl std::fmt::Display for SparseVectorQuery {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let dims = self.pruned_dims();
42 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
43 if self.heap_factor < 1.0 {
44 write!(f, ", heap={}", self.heap_factor)?;
45 }
46 if self.vector.len() != dims.len() {
47 write!(f, ", orig={}", self.vector.len())?;
48 }
49 write!(f, ")")
50 }
51}
52
53impl SparseVectorQuery {
54 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
61 Self {
62 field,
63 vector,
64 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
65 heap_factor: 1.0,
66 weight_threshold: 0.0,
67 max_query_dims: None,
68 pruning: None,
69 over_fetch_factor: 2.0,
70 pruned: None,
71 }
72 }
73
74 pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
76 self.pruned.as_deref().unwrap_or(&self.vector)
77 }
78
79 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
81 self.combiner = combiner;
82 self
83 }
84
85 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
90 self.over_fetch_factor = factor.max(1.0);
91 self
92 }
93
94 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
101 self.heap_factor = heap_factor.clamp(0.0, 1.0);
102 self
103 }
104
105 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
108 self.weight_threshold = threshold;
109 self.pruned = Some(self.compute_pruned_vector());
110 self
111 }
112
113 pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
115 self.max_query_dims = Some(max_dims);
116 self.pruned = Some(self.compute_pruned_vector());
117 self
118 }
119
120 pub fn with_pruning(mut self, fraction: f32) -> Self {
123 self.pruning = Some(fraction.clamp(0.0, 1.0));
124 self.pruned = Some(self.compute_pruned_vector());
125 self
126 }
127
128 fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
130 let original_len = self.vector.len();
131
132 let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
134 self.vector
135 .iter()
136 .copied()
137 .filter(|(_, w)| w.abs() >= self.weight_threshold)
138 .collect()
139 } else {
140 self.vector.clone()
141 };
142 let after_threshold = v.len();
143
144 let mut sorted_by_weight = false;
146 if let Some(fraction) = self.pruning
147 && fraction < 1.0
148 && v.len() > 1
149 {
150 v.sort_unstable_by(|a, b| {
151 b.1.abs()
152 .partial_cmp(&a.1.abs())
153 .unwrap_or(std::cmp::Ordering::Equal)
154 });
155 sorted_by_weight = true;
156 let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
157 v.truncate(keep);
158 }
159 let after_pruning = v.len();
160
161 if let Some(max_dims) = self.max_query_dims
163 && v.len() > max_dims
164 {
165 if !sorted_by_weight {
166 v.sort_unstable_by(|a, b| {
167 b.1.abs()
168 .partial_cmp(&a.1.abs())
169 .unwrap_or(std::cmp::Ordering::Equal)
170 });
171 }
172 v.truncate(max_dims);
173 }
174
175 if v.len() < original_len {
176 let src: Vec<_> = self
177 .vector
178 .iter()
179 .map(|(d, w)| format!("({},{:.4})", d, w))
180 .collect();
181 let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
182 log::debug!(
183 "[sparse query] field={}: pruned {}->{} dims \
184 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
185 source=[{}], pruned=[{}]",
186 self.field.0,
187 original_len,
188 v.len(),
189 original_len,
190 after_threshold,
191 after_threshold,
192 after_pruning,
193 after_pruning,
194 v.len(),
195 src.join(", "),
196 pruned_fmt.join(", "),
197 );
198 }
199
200 v
201 }
202
203 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
205 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
206 Self::new(field, vector)
207 }
208
209 #[cfg(feature = "native")]
221 pub fn from_text(
222 field: Field,
223 text: &str,
224 tokenizer_name: &str,
225 weighting: crate::structures::QueryWeighting,
226 sparse_index: Option<&crate::segment::SparseIndex>,
227 ) -> crate::Result<Self> {
228 use crate::structures::QueryWeighting;
229 use crate::tokenizer::tokenizer_cache;
230
231 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
232 let token_ids = tokenizer.tokenize_unique(text)?;
233
234 let weights: Vec<f32> = match weighting {
235 QueryWeighting::One => vec![1.0f32; token_ids.len()],
236 QueryWeighting::Idf => {
237 if let Some(index) = sparse_index {
238 index.idf_weights(&token_ids)
239 } else {
240 vec![1.0f32; token_ids.len()]
241 }
242 }
243 QueryWeighting::IdfFile => {
244 use crate::tokenizer::idf_weights_cache;
245 if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
246 token_ids.iter().map(|&id| idf.get(id)).collect()
247 } else {
248 vec![1.0f32; token_ids.len()]
249 }
250 }
251 };
252
253 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
254 Ok(Self::new(field, vector))
255 }
256
257 #[cfg(feature = "native")]
269 pub fn from_text_with_stats(
270 field: Field,
271 text: &str,
272 tokenizer: &crate::tokenizer::HfTokenizer,
273 weighting: crate::structures::QueryWeighting,
274 global_stats: Option<&crate::query::GlobalStats>,
275 ) -> crate::Result<Self> {
276 use crate::structures::QueryWeighting;
277
278 let token_ids = tokenizer.tokenize_unique(text)?;
279
280 let weights: Vec<f32> = match weighting {
281 QueryWeighting::One => vec![1.0f32; token_ids.len()],
282 QueryWeighting::Idf => {
283 if let Some(stats) = global_stats {
284 stats
286 .sparse_idf_weights(field, &token_ids)
287 .into_iter()
288 .map(|w| w.max(0.0))
289 .collect()
290 } else {
291 vec![1.0f32; token_ids.len()]
292 }
293 }
294 QueryWeighting::IdfFile => {
295 vec![1.0f32; token_ids.len()]
298 }
299 };
300
301 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
302 Ok(Self::new(field, vector))
303 }
304
305 #[cfg(feature = "native")]
317 pub fn from_text_with_tokenizer_bytes(
318 field: Field,
319 text: &str,
320 tokenizer_bytes: &[u8],
321 weighting: crate::structures::QueryWeighting,
322 global_stats: Option<&crate::query::GlobalStats>,
323 ) -> crate::Result<Self> {
324 use crate::structures::QueryWeighting;
325 use crate::tokenizer::HfTokenizer;
326
327 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
328 let token_ids = tokenizer.tokenize_unique(text)?;
329
330 let weights: Vec<f32> = match weighting {
331 QueryWeighting::One => vec![1.0f32; token_ids.len()],
332 QueryWeighting::Idf => {
333 if let Some(stats) = global_stats {
334 stats
336 .sparse_idf_weights(field, &token_ids)
337 .into_iter()
338 .map(|w| w.max(0.0))
339 .collect()
340 } else {
341 vec![1.0f32; token_ids.len()]
342 }
343 }
344 QueryWeighting::IdfFile => {
345 vec![1.0f32; token_ids.len()]
348 }
349 };
350
351 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
352 Ok(Self::new(field, vector))
353 }
354}
355
356impl SparseVectorQuery {
357 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
359 self.pruned_dims()
360 .iter()
361 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
362 field: self.field,
363 dim_id,
364 weight,
365 heap_factor: self.heap_factor,
366 combiner: self.combiner,
367 over_fetch_factor: self.over_fetch_factor,
368 })
369 .collect()
370 }
371}
372
373impl Query for SparseVectorQuery {
374 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
375 let infos = self.sparse_infos();
376
377 Box::pin(async move {
378 if infos.is_empty() {
379 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
380 }
381
382 if let Some((executor, info)) =
384 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
385 {
386 let raw = executor.execute().await?;
387 return Ok(crate::query::planner::combine_sparse_results(
388 raw,
389 info.combiner,
390 info.field,
391 limit,
392 ));
393 }
394
395 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
396 })
397 }
398
399 #[cfg(feature = "sync")]
400 fn scorer_sync<'a>(
401 &self,
402 reader: &'a SegmentReader,
403 limit: usize,
404 ) -> crate::Result<Box<dyn Scorer + 'a>> {
405 let infos = self.sparse_infos();
406 if infos.is_empty() {
407 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
408 }
409
410 if let Some((executor, info)) =
412 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
413 {
414 let raw = executor.execute_sync()?;
415 return Ok(crate::query::planner::combine_sparse_results(
416 raw,
417 info.combiner,
418 info.field,
419 limit,
420 ));
421 }
422
423 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>)
424 }
425
426 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
427 Box::pin(async move { Ok(u32::MAX) })
428 }
429
430 fn decompose(&self) -> crate::query::QueryDecomposition {
431 let infos = self.sparse_infos();
432 if infos.is_empty() {
433 crate::query::QueryDecomposition::Opaque
434 } else {
435 crate::query::QueryDecomposition::SparseTerms(infos)
436 }
437 }
438}
439
/// Query for a single dimension of a sparse-vector field.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Field whose sparse index is searched.
    pub field: Field,
    /// Dimension id looked up in the field's sparse index.
    pub dim_id: u32,
    /// Query weight handed to the term cursor for this dimension.
    pub weight: f32,
    /// Heap factor reported through `decompose`; `new` starts at `1.0`.
    pub heap_factor: f32,
    /// Combiner reported through `decompose`; `new` uses
    /// `MultiValueCombiner::default()`.
    pub combiner: MultiValueCombiner,
    /// Over-fetch factor reported through `decompose`; `new` starts at `2.0`
    /// and `with_over_fetch_factor` enforces a minimum of `1.0`.
    pub over_fetch_factor: f32,
}
459
460impl std::fmt::Display for SparseTermQuery {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 write!(
463 f,
464 "SparseTerm({}, dim={}, w={:.3})",
465 self.field.0, self.dim_id, self.weight
466 )
467 }
468}
469
470impl SparseTermQuery {
471 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
472 Self {
473 field,
474 dim_id,
475 weight,
476 heap_factor: 1.0,
477 combiner: MultiValueCombiner::default(),
478 over_fetch_factor: 2.0,
479 }
480 }
481
482 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
483 self.heap_factor = heap_factor;
484 self
485 }
486
487 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
488 self.combiner = combiner;
489 self
490 }
491
492 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
493 self.over_fetch_factor = factor.max(1.0);
494 self
495 }
496
497 fn make_scorer<'a>(
500 &self,
501 reader: &'a SegmentReader,
502 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
503 let si = match reader.sparse_index(self.field) {
504 Some(si) => si,
505 None => return Ok(None),
506 };
507 let (skip_start, skip_count, global_max, block_data_offset) =
508 match si.get_skip_range_full(self.dim_id) {
509 Some(v) => v,
510 None => return Ok(None),
511 };
512 let cursor = crate::query::TermCursor::sparse(
513 si,
514 self.weight,
515 skip_start,
516 skip_count,
517 global_max,
518 block_data_offset,
519 );
520 Ok(Some(SparseTermScorer {
521 cursor,
522 field_id: self.field.0,
523 }))
524 }
525}
526
527impl Query for SparseTermQuery {
528 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
529 let query = self.clone();
530 Box::pin(async move {
531 let mut scorer = match query.make_scorer(reader)? {
532 Some(s) => s,
533 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
534 };
535 scorer.cursor.ensure_block_loaded().await.ok();
536 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
537 })
538 }
539
540 #[cfg(feature = "sync")]
541 fn scorer_sync<'a>(
542 &self,
543 reader: &'a SegmentReader,
544 _limit: usize,
545 ) -> crate::Result<Box<dyn Scorer + 'a>> {
546 let mut scorer = match self.make_scorer(reader)? {
547 Some(s) => s,
548 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
549 };
550 scorer.cursor.ensure_block_loaded_sync().ok();
551 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
552 }
553
554 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
555 let field = self.field;
556 let dim_id = self.dim_id;
557 Box::pin(async move {
558 let si = match reader.sparse_index(field) {
559 Some(si) => si,
560 None => return Ok(0),
561 };
562 match si.get_skip_range_full(dim_id) {
563 Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
564 None => Ok(0),
565 }
566 })
567 }
568
569 fn decompose(&self) -> crate::query::QueryDecomposition {
570 crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
571 field: self.field,
572 dim_id: self.dim_id,
573 weight: self.weight,
574 heap_factor: self.heap_factor,
575 combiner: self.combiner,
576 over_fetch_factor: self.over_fetch_factor,
577 }])
578 }
579}
580
/// Scorer for a single sparse dimension, backed by a `TermCursor`.
struct SparseTermScorer<'a> {
    /// Cursor that supplies the current document, its score and its ordinal.
    cursor: crate::query::TermCursor<'a>,
    /// Raw id of the queried field, reported alongside matched positions.
    field_id: u32,
}
589
590impl crate::query::docset::DocSet for SparseTermScorer<'_> {
591 fn doc(&self) -> DocId {
592 let d = self.cursor.doc();
593 if d == u32::MAX { TERMINATED } else { d }
594 }
595
596 fn advance(&mut self) -> DocId {
597 match self.cursor.advance_sync() {
598 Ok(d) if d == u32::MAX => TERMINATED,
599 Ok(d) => d,
600 Err(_) => TERMINATED,
601 }
602 }
603
604 fn seek(&mut self, target: DocId) -> DocId {
605 match self.cursor.seek_sync(target) {
606 Ok(d) if d == u32::MAX => TERMINATED,
607 Ok(d) => d,
608 Err(_) => TERMINATED,
609 }
610 }
611
612 fn size_hint(&self) -> u32 {
613 0
614 }
615}
616
617impl Scorer for SparseTermScorer<'_> {
618 fn score(&self) -> Score {
619 self.cursor.score()
620 }
621
622 fn matched_positions(&self) -> Option<MatchedPositions> {
623 let ordinal = self.cursor.ordinal();
624 let score = self.cursor.score();
625 if score == 0.0 {
626 return None;
627 }
628 Some(vec![(
629 self.field_id,
630 vec![ScoredPosition::new(ordinal as u32, score)],
631 )])
632 }
633}
634
635#[cfg(test)]
636mod tests {
637 use super::*;
638 use crate::dsl::Field;
639
640 #[test]
641 fn test_sparse_vector_query_new() {
642 let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
643 let query = SparseVectorQuery::new(Field(0), sparse.clone());
644
645 assert_eq!(query.field, Field(0));
646 assert_eq!(query.vector, sparse);
647 }
648
649 #[test]
650 fn test_sparse_vector_query_from_indices_weights() {
651 let query =
652 SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);
653
654 assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
655 }
656}