1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::combiner::MultiValueCombiner;
8use crate::query::ScoredPosition;
9use crate::query::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
10
/// Weighted sparse-vector query against a single sparse field.
///
/// `vector` keeps the caller's `(dimension id, weight)` pairs untouched. The dimensions
/// that are actually executed come from a pruned copy held in the private `pruned` cache.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Sparse field to search.
    pub field: Field,
    /// Original `(dimension id, weight)` pairs, before any pruning.
    pub vector: Vec<(u32, f32)>,
    /// Combiner copied into every per-dimension `SparseTermQueryInfo`.
    pub combiner: MultiValueCombiner,
    /// Heap factor forwarded to the planner; `with_heap_factor` clamps it to `[0.0, 1.0]`.
    pub heap_factor: f32,
    /// Dimensions with `|weight|` below this value are dropped (only applied when `> 0.0`).
    pub weight_threshold: f32,
    /// Maximum number of dimensions kept (largest `|weight|` first); `None` means no cap.
    pub max_query_dims: Option<usize>,
    /// Fraction of dimensions (by `|weight|`) to keep; `None` disables fractional pruning.
    pub pruning: Option<f32>,
    /// Over-fetch factor forwarded to the planner; the builder keeps it `>= 1.0`.
    pub over_fetch_factor: f32,
    /// Cached output of the pruning pipeline over `vector`.
    ///
    /// NOTE(review): rebuilt only by `new` and the pruning-related `with_*` builders;
    /// assigning the public fields above directly leaves this cache stale.
    pruned: Option<Vec<(u32, f32)>>,
}
38
39impl std::fmt::Display for SparseVectorQuery {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let dims = self.pruned_dims();
42 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
43 if self.heap_factor < 1.0 {
44 write!(f, ", heap={}", self.heap_factor)?;
45 }
46 if self.vector.len() != dims.len() {
47 write!(f, ", orig={}", self.vector.len())?;
48 }
49 write!(f, ")")
50 }
51}
52
53impl SparseVectorQuery {
54 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
61 let mut q = Self {
62 field,
63 vector,
64 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
65 heap_factor: 1.0,
66 weight_threshold: 0.0,
67 max_query_dims: Some(crate::query::MAX_QUERY_TOKENS),
68 pruning: None,
69 over_fetch_factor: 2.0,
70 pruned: None,
71 };
72 q.pruned = Some(q.compute_pruned_vector());
73 q
74 }
75
76 pub(crate) fn pruned_dims(&self) -> &[(u32, f32)] {
78 self.pruned.as_deref().unwrap_or(&self.vector)
79 }
80
81 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
83 self.combiner = combiner;
84 self
85 }
86
87 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
92 self.over_fetch_factor = factor.max(1.0);
93 self
94 }
95
96 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
103 self.heap_factor = heap_factor.clamp(0.0, 1.0);
104 self
105 }
106
107 pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
110 self.weight_threshold = threshold;
111 self.pruned = Some(self.compute_pruned_vector());
112 self
113 }
114
115 pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
117 self.max_query_dims = Some(max_dims);
118 self.pruned = Some(self.compute_pruned_vector());
119 self
120 }
121
122 pub fn with_pruning(mut self, fraction: f32) -> Self {
125 self.pruning = Some(fraction.clamp(0.0, 1.0));
126 self.pruned = Some(self.compute_pruned_vector());
127 self
128 }
129
130 fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
132 let original_len = self.vector.len();
133
134 let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
136 self.vector
137 .iter()
138 .copied()
139 .filter(|(_, w)| w.abs() >= self.weight_threshold)
140 .collect()
141 } else {
142 self.vector.clone()
143 };
144 let after_threshold = v.len();
145
146 let mut sorted_by_weight = false;
148 if let Some(fraction) = self.pruning
149 && fraction < 1.0
150 && v.len() > 1
151 {
152 v.sort_unstable_by(|a, b| {
153 b.1.abs()
154 .partial_cmp(&a.1.abs())
155 .unwrap_or(std::cmp::Ordering::Equal)
156 });
157 sorted_by_weight = true;
158 let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
159 v.truncate(keep);
160 }
161 let after_pruning = v.len();
162
163 if let Some(max_dims) = self.max_query_dims
165 && v.len() > max_dims
166 {
167 if !sorted_by_weight {
168 v.sort_unstable_by(|a, b| {
169 b.1.abs()
170 .partial_cmp(&a.1.abs())
171 .unwrap_or(std::cmp::Ordering::Equal)
172 });
173 }
174 v.truncate(max_dims);
175 }
176
177 if v.len() < original_len && log::log_enabled!(log::Level::Debug) {
178 let src: Vec<_> = self
179 .vector
180 .iter()
181 .map(|(d, w)| format!("({},{:.4})", d, w))
182 .collect();
183 let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
184 log::debug!(
185 "[sparse query] field={}: pruned {}->{} dims \
186 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
187 source=[{}], pruned=[{}]",
188 self.field.0,
189 original_len,
190 v.len(),
191 original_len,
192 after_threshold,
193 after_threshold,
194 after_pruning,
195 after_pruning,
196 v.len(),
197 src.join(", "),
198 pruned_fmt.join(", "),
199 );
200 }
201
202 v
203 }
204
205 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
207 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
208 Self::new(field, vector)
209 }
210
211 #[cfg(feature = "native")]
223 pub fn from_text(
224 field: Field,
225 text: &str,
226 tokenizer_name: &str,
227 weighting: crate::structures::QueryWeighting,
228 sparse_index: Option<&crate::segment::SparseIndex>,
229 ) -> crate::Result<Self> {
230 use crate::structures::QueryWeighting;
231 use crate::tokenizer::tokenizer_cache;
232
233 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
234 let token_ids = tokenizer.tokenize_unique(text)?;
235
236 let weights: Vec<f32> = match weighting {
237 QueryWeighting::One => vec![1.0f32; token_ids.len()],
238 QueryWeighting::Idf => {
239 if let Some(index) = sparse_index {
240 index.idf_weights(&token_ids)
241 } else {
242 vec![1.0f32; token_ids.len()]
243 }
244 }
245 QueryWeighting::IdfFile => {
246 use crate::tokenizer::idf_weights_cache;
247 if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name, None) {
248 token_ids.iter().map(|&id| idf.get(id)).collect()
249 } else {
250 vec![1.0f32; token_ids.len()]
251 }
252 }
253 };
254
255 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
256 Ok(Self::new(field, vector))
257 }
258
259 #[cfg(feature = "native")]
271 pub fn from_text_with_stats(
272 field: Field,
273 text: &str,
274 tokenizer: &crate::tokenizer::HfTokenizer,
275 weighting: crate::structures::QueryWeighting,
276 global_stats: Option<&crate::query::GlobalStats>,
277 ) -> crate::Result<Self> {
278 use crate::structures::QueryWeighting;
279
280 let token_ids = tokenizer.tokenize_unique(text)?;
281
282 let weights: Vec<f32> = match weighting {
283 QueryWeighting::One => vec![1.0f32; token_ids.len()],
284 QueryWeighting::Idf => {
285 if let Some(stats) = global_stats {
286 stats
288 .sparse_idf_weights(field, &token_ids)
289 .into_iter()
290 .map(|w| w.max(0.0))
291 .collect()
292 } else {
293 vec![1.0f32; token_ids.len()]
294 }
295 }
296 QueryWeighting::IdfFile => {
297 vec![1.0f32; token_ids.len()]
300 }
301 };
302
303 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
304 Ok(Self::new(field, vector))
305 }
306
307 #[cfg(feature = "native")]
319 pub fn from_text_with_tokenizer_bytes(
320 field: Field,
321 text: &str,
322 tokenizer_bytes: &[u8],
323 weighting: crate::structures::QueryWeighting,
324 global_stats: Option<&crate::query::GlobalStats>,
325 ) -> crate::Result<Self> {
326 use crate::structures::QueryWeighting;
327 use crate::tokenizer::HfTokenizer;
328
329 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
330 let token_ids = tokenizer.tokenize_unique(text)?;
331
332 let weights: Vec<f32> = match weighting {
333 QueryWeighting::One => vec![1.0f32; token_ids.len()],
334 QueryWeighting::Idf => {
335 if let Some(stats) = global_stats {
336 stats
338 .sparse_idf_weights(field, &token_ids)
339 .into_iter()
340 .map(|w| w.max(0.0))
341 .collect()
342 } else {
343 vec![1.0f32; token_ids.len()]
344 }
345 }
346 QueryWeighting::IdfFile => {
347 vec![1.0f32; token_ids.len()]
350 }
351 };
352
353 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
354 Ok(Self::new(field, vector))
355 }
356}
357
358impl SparseVectorQuery {
359 fn sparse_infos(&self) -> Vec<crate::query::SparseTermQueryInfo> {
361 self.pruned_dims()
362 .iter()
363 .map(|&(dim_id, weight)| crate::query::SparseTermQueryInfo {
364 field: self.field,
365 dim_id,
366 weight,
367 heap_factor: self.heap_factor,
368 combiner: self.combiner,
369 over_fetch_factor: self.over_fetch_factor,
370 })
371 .collect()
372 }
373}
374
375impl Query for SparseVectorQuery {
376 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
377 let infos = self.sparse_infos();
378
379 Box::pin(async move {
380 if infos.is_empty() {
381 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>);
382 }
383
384 if let Some((executor, info)) =
386 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
387 {
388 let raw = executor.execute().await?;
389 return Ok(crate::query::planner::combine_sparse_results(
390 raw,
391 info.combiner,
392 info.field,
393 limit,
394 ));
395 }
396
397 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer>)
398 })
399 }
400
401 #[cfg(feature = "sync")]
402 fn scorer_sync<'a>(
403 &self,
404 reader: &'a SegmentReader,
405 limit: usize,
406 ) -> crate::Result<Box<dyn Scorer + 'a>> {
407 let infos = self.sparse_infos();
408 if infos.is_empty() {
409 return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>);
410 }
411
412 if let Some((executor, info)) =
414 crate::query::planner::build_sparse_maxscore_executor(&infos, reader, limit, None)
415 {
416 let raw = executor.execute_sync()?;
417 return Ok(crate::query::planner::combine_sparse_results(
418 raw,
419 info.combiner,
420 info.field,
421 limit,
422 ));
423 }
424
425 Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>)
426 }
427
428 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
429 Box::pin(async move { Ok(u32::MAX) })
430 }
431
432 fn decompose(&self) -> crate::query::QueryDecomposition {
433 let infos = self.sparse_infos();
434 if infos.is_empty() {
435 crate::query::QueryDecomposition::Opaque
436 } else {
437 crate::query::QueryDecomposition::SparseTerms(infos)
438 }
439 }
440}
441
/// Query matching a single sparse dimension `dim_id` of `field`, scored with `weight`.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Sparse field to search.
    pub field: Field,
    /// Dimension id within the sparse field.
    pub dim_id: u32,
    /// Query-side weight handed to the term cursor.
    pub weight: f32,
    /// Heap factor forwarded in this query's `SparseTermQueryInfo` decomposition.
    pub heap_factor: f32,
    /// Combiner forwarded in the decomposition (`MultiValueCombiner::default()` from `new`).
    pub combiner: MultiValueCombiner,
    /// Over-fetch factor forwarded in the decomposition; the builder keeps it `>= 1.0`.
    pub over_fetch_factor: f32,
}
461
462impl std::fmt::Display for SparseTermQuery {
463 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464 write!(
465 f,
466 "SparseTerm({}, dim={}, w={:.3})",
467 self.field.0, self.dim_id, self.weight
468 )
469 }
470}
471
472impl SparseTermQuery {
473 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
474 Self {
475 field,
476 dim_id,
477 weight,
478 heap_factor: 1.0,
479 combiner: MultiValueCombiner::default(),
480 over_fetch_factor: 2.0,
481 }
482 }
483
484 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
485 self.heap_factor = heap_factor;
486 self
487 }
488
489 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
490 self.combiner = combiner;
491 self
492 }
493
494 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
495 self.over_fetch_factor = factor.max(1.0);
496 self
497 }
498
499 fn make_scorer<'a>(
502 &self,
503 reader: &'a SegmentReader,
504 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
505 let si = match reader.sparse_index(self.field) {
506 Some(si) => si,
507 None => return Ok(None),
508 };
509 let (skip_start, skip_count, global_max, block_data_offset) =
510 match si.get_skip_range_full(self.dim_id) {
511 Some(v) => v,
512 None => return Ok(None),
513 };
514 let cursor = crate::query::TermCursor::sparse(
515 si,
516 self.weight,
517 skip_start,
518 skip_count,
519 global_max,
520 block_data_offset,
521 );
522 Ok(Some(SparseTermScorer {
523 cursor,
524 field_id: self.field.0,
525 }))
526 }
527}
528
529impl Query for SparseTermQuery {
530 fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
531 let query = self.clone();
532 Box::pin(async move {
533 let mut scorer = match query.make_scorer(reader)? {
534 Some(s) => s,
535 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
536 };
537 scorer.cursor.ensure_block_loaded().await.ok();
538 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
539 })
540 }
541
542 #[cfg(feature = "sync")]
543 fn scorer_sync<'a>(
544 &self,
545 reader: &'a SegmentReader,
546 _limit: usize,
547 ) -> crate::Result<Box<dyn Scorer + 'a>> {
548 let mut scorer = match self.make_scorer(reader)? {
549 Some(s) => s,
550 None => return Ok(Box::new(crate::query::EmptyScorer) as Box<dyn Scorer + 'a>),
551 };
552 scorer.cursor.ensure_block_loaded_sync().ok();
553 Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
554 }
555
556 fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
557 let field = self.field;
558 let dim_id = self.dim_id;
559 Box::pin(async move {
560 let si = match reader.sparse_index(field) {
561 Some(si) => si,
562 None => return Ok(0),
563 };
564 match si.get_skip_range_full(dim_id) {
565 Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
566 None => Ok(0),
567 }
568 })
569 }
570
571 fn decompose(&self) -> crate::query::QueryDecomposition {
572 crate::query::QueryDecomposition::SparseTerms(vec![crate::query::SparseTermQueryInfo {
573 field: self.field,
574 dim_id: self.dim_id,
575 weight: self.weight,
576 heap_factor: self.heap_factor,
577 combiner: self.combiner,
578 over_fetch_factor: self.over_fetch_factor,
579 }])
580 }
581}
582
/// Scorer over a single sparse dimension, driven by a `TermCursor`.
struct SparseTermScorer<'a> {
    /// Sparse-index cursor for the dimension, built by `TermCursor::sparse` with the query weight.
    cursor: crate::query::TermCursor<'a>,
    /// Numeric id of the queried field, reported by `matched_positions`.
    field_id: u32,
}
591
592impl crate::query::docset::DocSet for SparseTermScorer<'_> {
593 fn doc(&self) -> DocId {
594 let d = self.cursor.doc();
595 if d == u32::MAX { TERMINATED } else { d }
596 }
597
598 fn advance(&mut self) -> DocId {
599 match self.cursor.advance_sync() {
600 Ok(d) if d == u32::MAX => TERMINATED,
601 Ok(d) => d,
602 Err(_) => TERMINATED,
603 }
604 }
605
606 fn seek(&mut self, target: DocId) -> DocId {
607 match self.cursor.seek_sync(target) {
608 Ok(d) if d == u32::MAX => TERMINATED,
609 Ok(d) => d,
610 Err(_) => TERMINATED,
611 }
612 }
613
614 fn size_hint(&self) -> u32 {
615 0
616 }
617}
618
619impl Scorer for SparseTermScorer<'_> {
620 fn score(&self) -> Score {
621 self.cursor.score()
622 }
623
624 fn matched_positions(&self) -> Option<MatchedPositions> {
625 let ordinal = self.cursor.ordinal();
626 let score = self.cursor.score();
627 if score == 0.0 {
628 return None;
629 }
630 Some(vec![(
631 self.field_id,
632 vec![ScoredPosition::new(ordinal as u32, score)],
633 )])
634 }
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_weight_threshold_drops_small_magnitudes() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5), (2, 0.05), (3, -0.4)])
            .with_max_query_dims(16)
            .with_weight_threshold(0.1);

        // Threshold filtering preserves the original order and compares |weight|.
        assert_eq!(query.pruned_dims().to_vec(), vec![(1, 0.5), (3, -0.4)]);
        // The caller's vector is left untouched.
        assert_eq!(query.vector.len(), 3);
    }

    #[test]
    fn test_pruning_keeps_heaviest_fraction() {
        let query =
            SparseVectorQuery::new(Field(0), vec![(1, 0.1), (2, -0.9), (3, 0.5), (4, 0.2)])
                .with_pruning(0.5);

        assert_eq!(query.pruned_dims().to_vec(), vec![(2, -0.9), (3, 0.5)]);
    }

    #[test]
    fn test_pruning_keeps_at_least_one_dim() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.1), (2, 0.3)]).with_pruning(0.0);

        assert_eq!(query.pruned_dims().to_vec(), vec![(2, 0.3)]);
    }

    #[test]
    fn test_max_query_dims_keeps_heaviest() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.1), (2, 0.9), (3, 0.5)])
            .with_max_query_dims(2);

        assert_eq!(query.pruned_dims().to_vec(), vec![(2, 0.9), (3, 0.5)]);
        assert_eq!(query.to_string(), "Sparse(0, dims=2, orig=3)");
    }

    #[test]
    fn test_vector_builders_clamp_factors() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 1.0)])
            .with_heap_factor(2.0)
            .with_over_fetch_factor(0.5);

        assert_eq!(query.heap_factor, 1.0);
        assert_eq!(query.over_fetch_factor, 1.0);
        assert_eq!(query.to_string(), "Sparse(0, dims=1)");
    }

    #[test]
    fn test_sparse_term_query_defaults_and_display() {
        let query = SparseTermQuery::new(Field(3), 7, 0.5).with_over_fetch_factor(0.25);

        assert_eq!(query.heap_factor, 1.0);
        assert_eq!(query.over_fetch_factor, 1.0);
        assert_eq!(query.to_string(), "SparseTerm(3, dim=7, w=0.500)");
    }
}