1use serde::{Deserialize, Serialize};
11
/// Reserved string key with the literal value `"__quiver_sparse__"`.
///
/// NOTE(review): nothing in this file reads the constant; presumably it is the
/// payload/metadata key under which a record's sparse vector is stored —
/// confirm against the callers before relying on that.
pub const SPARSE_KEY: &str = "__quiver_sparse__";
14
/// Default `k0` smoothing constant to pass to [`rrf_fuse`].
///
/// `rrf_fuse` scores each hit as `1 / (k0 + rank + 1)`, so a larger `k0`
/// shrinks the gap between neighbouring ranks.
pub const DEFAULT_RRF_K0: f32 = 60.0;
17
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
21pub struct SparseVector {
22 pub indices: Vec<u32>,
24 pub values: Vec<f32>,
26}
27
28impl SparseVector {
29 pub fn validate(&self) -> Result<(), String> {
31 if self.indices.len() != self.values.len() {
32 return Err(format!(
33 "sparse vector indices ({}) and values ({}) length mismatch",
34 self.indices.len(),
35 self.values.len()
36 ));
37 }
38 let mut seen = self.indices.clone();
39 seen.sort_unstable();
40 if seen.windows(2).any(|w| w[0] == w[1]) {
41 return Err("sparse vector has duplicate indices".to_owned());
42 }
43 Ok(())
44 }
45
46 pub fn len(&self) -> usize {
48 self.indices.len()
49 }
50
51 pub fn is_empty(&self) -> bool {
53 self.indices.is_empty()
54 }
55
56 pub fn normalized(&self) -> SparseVector {
59 let mut pairs: Vec<(u32, f32)> = self
60 .indices
61 .iter()
62 .copied()
63 .zip(self.values.iter().copied())
64 .collect();
65 pairs.sort_by_key(|&(i, _)| i);
66 SparseVector {
67 indices: pairs.iter().map(|&(i, _)| i).collect(),
68 values: pairs.iter().map(|&(_, v)| v).collect(),
69 }
70 }
71
72 pub fn dot(&self, other: &SparseVector) -> f32 {
75 use std::collections::HashMap;
76 let lhs: HashMap<u32, f32> = self
77 .indices
78 .iter()
79 .copied()
80 .zip(self.values.iter().copied())
81 .collect();
82 let mut sum = 0.0f32;
83 for (i, v) in other.indices.iter().zip(other.values.iter()) {
84 if let Some(w) = lhs.get(i) {
85 sum += w * v;
86 }
87 }
88 sum
89 }
90}
91
/// Fuses several ranked id lists with Reciprocal Rank Fusion (RRF).
///
/// Each id earns `1 / (k0 + rank + 1)` from every position it occupies, where
/// `rank` is the zero-based position within a list; the contributions are
/// summed per id across all lists.
///
/// # Arguments
///
/// * `rankings` - ranked id lists, best hit first.
/// * `k0` - smoothing constant added to every rank; larger values flatten the
///   difference between neighbouring ranks.
/// * `top_k` - maximum number of fused results to return.
///
/// # Returns
///
/// At most `top_k` `(id, score)` pairs ordered by descending score. Ids with
/// equal (or incomparable, i.e. NaN) scores are ordered by ascending id, so
/// the output does not depend on hash-map iteration order.
pub fn rrf_fuse(rankings: &[Vec<String>], k0: f32, top_k: usize) -> Vec<(String, f32)> {
    use std::collections::HashMap;

    // Accumulate under borrowed ids: no `String` is allocated per occurrence,
    // and ids that fall outside `top_k` are never cloned at all.
    let mut scores: HashMap<&str, f32> = HashMap::new();
    for ranking in rankings {
        for (rank, id) in ranking.iter().enumerate() {
            *scores.entry(id.as_str()).or_insert(0.0) += 1.0 / (k0 + rank as f32 + 1.0);
        }
    }

    let mut fused: Vec<(&str, f32)> = scores.into_iter().collect();
    // Map keys are unique and the comparator tie-breaks on id, so no two
    // entries compare equal and an unstable sort yields the same order.
    fused.sort_unstable_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(b.0))
    });
    fused.truncate(top_k);

    // Only the survivors are turned into owned strings.
    fused
        .into_iter()
        .map(|(id, score)| (id.to_owned(), score))
        .collect()
}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SparseVector` from borrowed slices to keep the cases compact.
    fn sv(indices: &[u32], values: &[f32]) -> SparseVector {
        SparseVector {
            indices: indices.to_vec(),
            values: values.to_vec(),
        }
    }

    /// Builds an owned ranking from string literals.
    fn ranked(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_owned()).collect()
    }

    #[test]
    fn validate_catches_length_mismatch_and_dupes() {
        // More indices than values.
        assert!(sv(&[1, 2], &[1.0]).validate().is_err());
        // The same index listed twice.
        assert!(sv(&[1, 1], &[1.0, 2.0]).validate().is_err());
        // Unsorted but unique indices are accepted.
        assert!(sv(&[3, 1, 2], &[1.0, 2.0, 3.0]).validate().is_ok());
    }

    #[test]
    fn dot_is_order_independent_and_uses_shared_dims() {
        let a = sv(&[1, 5, 9], &[1.0, 2.0, 3.0]);
        let b = sv(&[9, 1, 7], &[10.0, 4.0, 1.0]);
        // Shared dimensions are 1 and 9: 1*4 + 3*10.
        assert_eq!(a.dot(&b), 34.0);
        assert_eq!(a.dot(&b), b.dot(&a));
    }

    #[test]
    fn normalized_sorts_indices_keeping_values_parallel() {
        let sorted = sv(&[5, 1, 3], &[50.0, 10.0, 30.0]).normalized();
        assert_eq!(sorted.indices, vec![1, 3, 5]);
        assert_eq!(sorted.values, vec![10.0, 30.0, 50.0]);
    }

    #[test]
    fn rrf_rewards_agreement_across_lists() {
        let dense = ranked(&["a", "b", "c"]);
        let sparse = ranked(&["b", "a", "d"]);
        let fused = rrf_fuse(&[dense, sparse], DEFAULT_RRF_K0, 10);
        let ids: Vec<&str> = fused.iter().map(|(id, _)| id.as_str()).collect();
        // "a" and "b" each hold ranks 0 and 1 once, so they tie and sort by id.
        assert_eq!(&ids[..2], &["a", "b"]);
        assert!((fused[0].1 - fused[1].1).abs() < 1e-9);
        // Anything found by only one list scores lower.
        assert!(fused[2].1 < fused[0].1);
    }

    #[test]
    fn rrf_truncates_to_top_k() {
        let only = ranked(&["a", "b", "c"]);
        let fused = rrf_fuse(&[only], DEFAULT_RRF_K0, 2);
        assert_eq!(fused.len(), 2);
    }
}