use serde::{Deserialize, Serialize};
/// Reserved key string associated with sparse-vector data.
///
/// NOTE(review): not referenced in the visible code; presumably the
/// payload/metadata key under which a sparse vector is stored — confirm
/// against callers.
pub const SPARSE_KEY: &str = "__quiver_sparse__";
/// Default value for the `k0` argument of `rrf_fuse`.
///
/// Each rank contributes `1 / (k0 + rank + 1)`, so a larger `k0` flattens the
/// difference between neighbouring ranks.
pub const DEFAULT_RRF_K0: f32 = 60.0;
/// A sparse vector stored as two parallel arrays: `indices[i]` is the
/// dimension whose weight is `values[i]`.
///
/// The type itself does not enforce that both arrays have the same length or
/// that indices are unique; `validate` checks both conditions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseVector {
/// Dimension ids of the stored entries. They need not be sorted;
/// `normalized` returns a copy ordered by index.
pub indices: Vec<u32>,
/// Weights, paired positionally with `indices`.
pub values: Vec<f32>,
}
impl SparseVector {
    /// Checks the structural invariants of the vector.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `indices` and `values` differ in
    /// length, or when any index occurs more than once. Index order is not
    /// checked, so unsorted vectors are accepted.
    pub fn validate(&self) -> Result<(), String> {
        let (index_count, value_count) = (self.indices.len(), self.values.len());
        if index_count != value_count {
            return Err(format!(
                "sparse vector indices ({}) and values ({}) length mismatch",
                index_count, value_count
            ));
        }

        // Sorting a scratch copy makes repeated indices adjacent, so `dedup`
        // shrinks the copy exactly when a duplicate exists.
        let mut unique = self.indices.to_vec();
        unique.sort_unstable();
        unique.dedup();
        if unique.len() < index_count {
            return Err(String::from("sparse vector has duplicate indices"));
        }

        Ok(())
    }

    /// Number of stored entries, measured on `indices`.
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// `true` when no indices are stored.
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Returns a copy whose entries are ordered by ascending index, with each
    /// value kept next to its index.
    ///
    /// Entries are paired positionally; if the two arrays differ in length the
    /// unpaired tail of the longer one is dropped.
    pub fn normalized(&self) -> SparseVector {
        let mut entries: Vec<(u32, f32)> = self
            .indices
            .iter()
            .zip(&self.values)
            .map(|(&index, &value)| (index, value))
            .collect();

        // Stable sort: entries sharing an index keep their original order.
        entries.sort_by(|left, right| left.0.cmp(&right.0));

        let (indices, values): (Vec<u32>, Vec<f32>) = entries.into_iter().unzip();
        SparseVector { indices, values }
    }

    /// Dot product over the dimensions present in both vectors.
    ///
    /// Neither vector has to be sorted. If `self` repeats an index, only the
    /// last weight for that index is used; every entry of `other` whose index
    /// is present in `self` contributes one product.
    pub fn dot(&self, other: &SparseVector) -> f32 {
        use std::collections::HashMap;

        // Index -> weight lookup for `self`; later inserts overwrite earlier ones.
        let mut weights: HashMap<u32, f32> = HashMap::new();
        for (&index, &weight) in self.indices.iter().zip(&self.values) {
            weights.insert(index, weight);
        }

        // Accumulate left to right in `other`'s order, starting from 0.0.
        other
            .indices
            .iter()
            .zip(&other.values)
            .filter_map(|(index, value)| weights.get(index).map(|weight| weight * value))
            .fold(0.0f32, |total, product| total + product)
    }
}
/// Fuses several ranked id lists into one list by summing reciprocal-rank
/// scores.
///
/// An id at zero-based position `rank` of a list contributes
/// `1 / (k0 + rank + 1)`; contributions are summed over every occurrence in
/// every list (an id repeated inside one list is counted each time).
///
/// The result is ordered by descending score. Equal scores, and scores that
/// cannot be compared (NaN), fall back to ascending id order. At most `top_k`
/// `(id, score)` pairs are returned.
pub fn rrf_fuse(rankings: &[Vec<String>], k0: f32, top_k: usize) -> Vec<(String, f32)> {
    use std::cmp::Ordering;
    use std::collections::HashMap;

    // Accumulate per-id totals, borrowing the ids from the input lists.
    let mut totals: HashMap<&str, f32> = HashMap::new();
    for list in rankings {
        for (position, doc_id) in list.iter().enumerate() {
            let contribution = 1.0 / (k0 + position as f32 + 1.0);
            *totals.entry(doc_id.as_str()).or_insert(0.0) += contribution;
        }
    }

    let mut ordered: Vec<(&str, f32)> = totals.into_iter().collect();
    ordered.sort_by(|(left_id, left_score), (right_id, right_score)| {
        // Compare right-to-left so higher scores sort first.
        match right_score.partial_cmp(left_score) {
            Some(Ordering::Equal) | None => left_id.cmp(right_id),
            Some(decided) => decided,
        }
    });

    // Only the surviving entries are turned into owned strings.
    ordered
        .into_iter()
        .take(top_k)
        .map(|(doc_id, score)| (doc_id.to_owned(), score))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SparseVector` from borrowed slices to keep the cases compact.
    fn sparse(indices: &[u32], values: &[f32]) -> SparseVector {
        SparseVector {
            indices: indices.to_vec(),
            values: values.to_vec(),
        }
    }

    /// Turns string literals into an owned ranking list.
    fn ranking(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| (*id).to_owned()).collect()
    }

    #[test]
    fn validate_catches_length_mismatch_and_dupes() {
        // Two indices but a single value.
        let length_mismatch = sparse(&[1, 2], &[1.0]);
        assert!(length_mismatch.validate().is_err());

        // Index 1 appears twice.
        let duplicated = sparse(&[1, 1], &[1.0, 2.0]);
        assert!(duplicated.validate().is_err());

        // Unsorted but otherwise well-formed input is accepted.
        let unsorted = sparse(&[3, 1, 2], &[1.0, 2.0, 3.0]);
        assert!(unsorted.validate().is_ok());
    }

    #[test]
    fn dot_is_order_independent_and_uses_shared_dims() {
        let a = sparse(&[1, 5, 9], &[1.0, 2.0, 3.0]);
        let b = sparse(&[9, 1, 7], &[10.0, 4.0, 1.0]);

        // Shared dimensions are 1 and 9: 1*4 + 3*10.
        assert_eq!(a.dot(&b), 34.0);
        assert_eq!(a.dot(&b), b.dot(&a));
    }

    #[test]
    fn normalized_sorts_indices_keeping_values_parallel() {
        let n = sparse(&[5, 1, 3], &[50.0, 10.0, 30.0]).normalized();

        assert_eq!(n.indices, vec![1, 3, 5]);
        assert_eq!(n.values, vec![10.0, 30.0, 50.0]);
    }

    #[test]
    fn rrf_rewards_agreement_across_lists() {
        let dense = ranking(&["a", "b", "c"]);
        let sparse_hits = ranking(&["b", "a", "d"]);

        let fused = rrf_fuse(&[dense, sparse_hits], DEFAULT_RRF_K0, 10);
        let ids: Vec<&str> = fused.iter().map(|(id, _)| id.as_str()).collect();

        // "a" and "b" appear in both lists with mirrored ranks, so they tie
        // on score and are ordered by id.
        assert_eq!(&ids[..2], &["a", "b"]);
        assert!((fused[0].1 - fused[1].1).abs() < 1e-9);
        // Ids present in only one list score strictly lower.
        assert!(fused[2].1 < fused[0].1);
    }

    #[test]
    fn rrf_truncates_to_top_k() {
        let only = ranking(&["a", "b", "c"]);
        let fused = rrf_fuse(&[only], DEFAULT_RRF_K0, 2);

        assert_eq!(fused.len(), 2);
    }
}