Skip to main content

quiver_query/
sparse.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2//! Sparse vectors and Reciprocal Rank Fusion for hybrid search (ADR-0043).
3//!
4//! A [`SparseVector`] is a learned-sparse (SPLADE/BGE-M3) or lexical term-weight
5//! vector — parallel `indices` (dimension ids) and `values` (weights). It rides
6//! in the point payload under [`SPARSE_KEY`] (no on-disk format change); the
7//! embeddable engine builds a derived inverted index from it. [`rrf_fuse`] merges
8//! the dense and sparse result lists by rank, the standard hybrid fuser.
9
10use serde::{Deserialize, Serialize};
11
/// Payload key under which a point carries its sparse vector (ADR-0043).
///
/// Reserved: ordinary payload data should not be stored under this key.
pub const SPARSE_KEY: &str = "__quiver_sparse__";
14
/// Default rank-bias constant `k0` for Reciprocal Rank Fusion.
///
/// 60 is the value used by Cormack, Clarke & Büttcher (SIGIR 2009). In the
/// per-list contribution `1 / (k0 + rank + 1)`, a larger `k0` makes scores
/// decay more slowly with rank.
pub const DEFAULT_RRF_K0: f32 = 60.0;
17
18/// A sparse vector: parallel `indices` and `values`. Indices are dimension ids
19/// into a (possibly very large) sparse vocabulary; values are their weights.
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub struct SparseVector {
22    /// Dimension ids. After [`SparseVector::normalized`] they are sorted and unique.
23    pub indices: Vec<u32>,
24    /// Per-index weights, parallel to `indices`.
25    pub values: Vec<f32>,
26}
27
28impl SparseVector {
29    /// Validate shape: equal-length, and no duplicate index after sorting.
30    pub fn validate(&self) -> Result<(), String> {
31        if self.indices.len() != self.values.len() {
32            return Err(format!(
33                "sparse vector indices ({}) and values ({}) length mismatch",
34                self.indices.len(),
35                self.values.len()
36            ));
37        }
38        let mut seen = self.indices.clone();
39        seen.sort_unstable();
40        if seen.windows(2).any(|w| w[0] == w[1]) {
41            return Err("sparse vector has duplicate indices".to_owned());
42        }
43        Ok(())
44    }
45
46    /// Number of non-zero terms.
47    pub fn len(&self) -> usize {
48        self.indices.len()
49    }
50
51    /// Whether the vector has no terms.
52    pub fn is_empty(&self) -> bool {
53        self.indices.is_empty()
54    }
55
56    /// Return a copy with indices sorted ascending (values kept parallel). The
57    /// canonical form the inverted index and `dot` assume.
58    pub fn normalized(&self) -> SparseVector {
59        let mut pairs: Vec<(u32, f32)> = self
60            .indices
61            .iter()
62            .copied()
63            .zip(self.values.iter().copied())
64            .collect();
65        pairs.sort_by_key(|&(i, _)| i);
66        SparseVector {
67            indices: pairs.iter().map(|&(i, _)| i).collect(),
68            values: pairs.iter().map(|&(_, v)| v).collect(),
69        }
70    }
71
72    /// Dot product with another sparse vector. Order-independent (builds a small
73    /// lookup over `self`), so callers need not pre-sort.
74    pub fn dot(&self, other: &SparseVector) -> f32 {
75        use std::collections::HashMap;
76        let lhs: HashMap<u32, f32> = self
77            .indices
78            .iter()
79            .copied()
80            .zip(self.values.iter().copied())
81            .collect();
82        let mut sum = 0.0f32;
83        for (i, v) in other.indices.iter().zip(other.values.iter()) {
84            if let Some(w) = lhs.get(i) {
85                sum += w * v;
86            }
87        }
88        sum
89    }
90}
91
/// Fuse several ranked id lists by Reciprocal Rank Fusion and return the top
/// `top_k` ids with their fused scores, highest first.
///
/// For each list, a document at 0-based `rank` contributes `1 / (k0 + rank + 1)`;
/// the contributions sum across lists. In RRF each list assigns a document one
/// rank, so if an id repeats within a single list only its first (best-ranked)
/// occurrence contributes. RRF is rank-based, so the (incomparable)
/// dense-distance and sparse-dot scales need no normalisation — the property
/// that makes it the standard, robust hybrid fuser. Ties break by id
/// (ascending) for determinism.
///
/// `k0` is expected to be positive (see [`DEFAULT_RRF_K0`]). With `k0 <= -1`
/// the denominator can reach zero or go negative, producing infinite or
/// negative contributions. Scores are ordered with `f32::total_cmp`, so the
/// result order stays well-defined even then.
pub fn rrf_fuse(rankings: &[Vec<String>], k0: f32, top_k: usize) -> Vec<(String, f32)> {
    use std::collections::HashMap;
    // Keyed by borrowed id so each id is cloned at most once, on output. The
    // `usize` is the index of the last list that contributed to this id, which
    // lets a repeat inside the same list be skipped.
    let mut scores: HashMap<&str, (f32, usize)> = HashMap::new();
    for (list, ranking) in rankings.iter().enumerate() {
        for (rank, id) in ranking.iter().enumerate() {
            let (score, last_list) = scores.entry(id.as_str()).or_insert((0.0, usize::MAX));
            if *last_list == list {
                continue;
            }
            *last_list = list;
            *score += 1.0 / (k0 + rank as f32 + 1.0);
        }
    }
    let mut fused: Vec<(&str, f32)> = scores
        .into_iter()
        .map(|(id, (score, _))| (id, score))
        .collect();
    // Ids are unique map keys, so this comparator is a strict total order and
    // the unstable sort is still deterministic.
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(b.0)));
    fused
        .into_iter()
        .take(top_k)
        .map(|(id, score)| (id.to_owned(), score))
        .collect()
}
111            .then_with(|| a.0.cmp(&b.0))
112    });
113    fused.truncate(top_k);
114    fused
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_catches_length_mismatch_and_dupes() {
        assert!(SparseVector {
            indices: vec![1, 2],
            values: vec![1.0]
        }
        .validate()
        .is_err());
        assert!(SparseVector {
            indices: vec![1, 1],
            values: vec![1.0, 2.0]
        }
        .validate()
        .is_err());
        assert!(SparseVector {
            indices: vec![3, 1, 2],
            values: vec![1.0, 2.0, 3.0]
        }
        .validate()
        .is_ok());
    }

    #[test]
    fn dot_is_order_independent_and_uses_shared_dims() {
        let a = SparseVector {
            indices: vec![1, 5, 9],
            values: vec![1.0, 2.0, 3.0],
        };
        let b = SparseVector {
            indices: vec![9, 1, 7],
            values: vec![10.0, 4.0, 1.0],
        };
        // shared dims: 1 (1*4) + 9 (3*10) = 34
        assert_eq!(a.dot(&b), 34.0);
        assert_eq!(a.dot(&b), b.dot(&a));
    }

    #[test]
    fn dot_of_disjoint_or_differently_sized_vectors() {
        let a = SparseVector {
            indices: vec![1, 2],
            values: vec![1.0, 2.0],
        };
        let b = SparseVector {
            indices: vec![3, 4, 5],
            values: vec![3.0, 4.0, 5.0],
        };
        // No shared dimension contributes nothing.
        assert_eq!(a.dot(&b), 0.0);
        assert_eq!(b.dot(&a), 0.0);
        // A single shared dim between operands of different length.
        let c = SparseVector {
            indices: vec![4],
            values: vec![0.5],
        };
        assert_eq!(b.dot(&c), 2.0);
        assert_eq!(c.dot(&b), 2.0);
    }

    #[test]
    fn empty_vector_is_valid_and_inert() {
        let empty = SparseVector {
            indices: vec![],
            values: vec![],
        };
        assert!(empty.validate().is_ok());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.normalized(), empty);
        let a = SparseVector {
            indices: vec![1, 2],
            values: vec![1.0, 2.0],
        };
        assert_eq!(a.dot(&empty), 0.0);
        assert_eq!(empty.dot(&a), 0.0);
    }

    #[test]
    fn normalized_sorts_indices_keeping_values_parallel() {
        let n = SparseVector {
            indices: vec![5, 1, 3],
            values: vec![50.0, 10.0, 30.0],
        }
        .normalized();
        assert_eq!(n.indices, vec![1, 3, 5]);
        assert_eq!(n.values, vec![10.0, 30.0, 50.0]);
    }

    #[test]
    fn rrf_rewards_agreement_across_lists() {
        let dense = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        let sparse = vec!["b".to_owned(), "a".to_owned(), "d".to_owned()];
        let fused = rrf_fuse(&[dense, sparse], DEFAULT_RRF_K0, 10);
        // "a" (ranks 0,1) and "b" (ranks 1,0) appear in both → top two; both equal.
        let ids: Vec<&str> = fused.iter().map(|(id, _)| id.as_str()).collect();
        assert_eq!(&ids[..2], &["a", "b"]);
        // a: 1/61 + 1/62 ; b: 1/62 + 1/61 → equal, so id order breaks the tie.
        assert!((fused[0].1 - fused[1].1).abs() < 1e-9);
        // c and d (single-list) score below.
        assert!(fused[2].1 < fused[0].1);
    }

    #[test]
    fn rrf_scores_follow_the_formula() {
        let single = vec!["x".to_owned(), "y".to_owned()];
        let fused = rrf_fuse(&[single], DEFAULT_RRF_K0, 10);
        // rank 0 → 1/(60+0+1), rank 1 → 1/(60+1+1)
        assert_eq!(fused[0].0, "x");
        assert!((fused[0].1 - 1.0 / 61.0).abs() < 1e-7);
        assert_eq!(fused[1].0, "y");
        assert!((fused[1].1 - 1.0 / 62.0).abs() < 1e-7);
    }

    #[test]
    fn rrf_truncates_to_top_k() {
        let r = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        assert_eq!(rrf_fuse(&[r], DEFAULT_RRF_K0, 2).len(), 2);
    }

    #[test]
    fn rrf_handles_empty_input_and_zero_top_k() {
        assert!(rrf_fuse(&[], DEFAULT_RRF_K0, 5).is_empty());
        assert!(rrf_fuse(&[Vec::new()], DEFAULT_RRF_K0, 5).is_empty());
        let r = vec!["a".to_owned()];
        assert!(rrf_fuse(&[r], DEFAULT_RRF_K0, 0).is_empty());
    }
}
190    fn rrf_truncates_to_top_k() {
191        let r = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
192        assert_eq!(rrf_fuse(&[r], DEFAULT_RRF_K0, 2).len(), 2);
193    }
194}