use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Activation applied to raw logits to turn them into term weights.
///
/// The formula for each variant is the one used by
/// `SpladeEncoder::apply_activation`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
/// `ln(1 + x)` applied directly to each logit.
Log1p,
/// `max(x, 0)`.
Relu,
/// `ln(1 + max(x, 0))`.
Log1pRelu,
}
/// Settings for a `SpladeEncoder`.
#[derive(Debug, Clone)]
pub struct SpladeConfig {
/// Used as the `dim` of every vector returned by `SpladeEncoder::encode`.
pub vocab_size: usize,
/// Selects the formula used by `SpladeEncoder::apply_activation`.
pub activation: ActivationType,
/// Passed to `SparseVector::prune` during encoding; weights below it are dropped.
pub sparsity_threshold: f32,
/// When `Some(n)`, encoding keeps at most the `n` highest-weighted terms
/// (via `SparseVector::top_k`); `None` disables the cap.
pub max_terms: Option<usize>,
/// Not read by any code visible in this file.
/// NOTE(review): presumably the location of model weights — confirm against callers.
pub model_path: Option<String>,
}
impl Default for SpladeConfig {
fn default() -> Self {
Self {
vocab_size: 30522, activation: ActivationType::Log1pRelu,
sparsity_threshold: 0.01,
max_terms: Some(256),
model_path: None,
}
}
}
/// A vector stored as two parallel lists: entry positions and their weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseVector {
/// Positions of the stored entries; `indices[i]` pairs with `weights[i]`.
pub indices: Vec<usize>,
/// Weight of each stored entry, in the same order as `indices`.
pub weights: Vec<f32>,
/// Length of the dense vector this represents (see `to_dense`).
pub dim: usize,
}
impl SparseVector {
    /// Builds a sparse vector from parallel `indices` / `weights` lists.
    ///
    /// `dim` is the length of the dense vector being represented.
    ///
    /// # Panics
    ///
    /// Panics if `indices` and `weights` have different lengths.
    pub fn new(indices: Vec<usize>, weights: Vec<f32>, dim: usize) -> Self {
        assert_eq!(
            indices.len(),
            weights.len(),
            "indices and weights must have the same length"
        );
        Self {
            indices,
            weights,
            dim,
        }
    }

    /// Iterates over the stored `(index, weight)` pairs in storage order.
    fn entries(&self) -> impl Iterator<Item = (usize, f32)> + '_ {
        self.indices
            .iter()
            .copied()
            .zip(self.weights.iter().copied())
    }

    /// Orders weights from largest to smallest, placing NaN after every
    /// real number so that a NaN weight can never panic the sort and is the
    /// first thing discarded by a truncation.
    fn cmp_weight_desc(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    /// Returns the dot product of `self` and `other`.
    ///
    /// Only indices present in both vectors contribute. If `other` lists an
    /// index more than once, its last weight for that index is the one used.
    pub fn dot(&self, other: &SparseVector) -> f32 {
        let other_weights: HashMap<usize, f32> = other.entries().collect();

        let mut score = 0.0;
        for (idx, weight) in self.entries() {
            if let Some(&other_weight) = other_weights.get(&idx) {
                score += weight * other_weight;
            }
        }
        score
    }

    /// Sum of the absolute values of the stored weights.
    pub fn l1_norm(&self) -> f32 {
        self.weights.iter().map(|w| w.abs()).sum()
    }

    /// Square root of the sum of squared stored weights.
    pub fn l2_norm(&self) -> f32 {
        self.weights.iter().map(|w| w * w).sum::<f32>().sqrt()
    }

    /// Drops every entry whose weight is below `threshold`.
    ///
    /// Entries with `weight >= threshold` are kept in their original order.
    /// NaN weights never satisfy the comparison and are therefore removed.
    pub fn prune(&mut self, threshold: f32) {
        let (indices, weights): (Vec<usize>, Vec<f32>) = self
            .entries()
            .filter(|&(_, w)| w >= threshold)
            .unzip();
        self.indices = indices;
        self.weights = weights;
    }

    /// Keeps only the `k` entries with the largest weights.
    ///
    /// Does nothing when the vector already has `k` or fewer entries.
    /// Otherwise the surviving entries are stored in descending weight order;
    /// entries with equal weights keep their relative order, and NaN weights
    /// rank last (they are discarded first).
    pub fn top_k(&mut self, k: usize) {
        if self.indices.len() <= k {
            return;
        }

        let mut pairs: Vec<(usize, f32)> = self.entries().collect();
        // Stable sort so ties preserve the original entry order.
        pairs.sort_by(|a, b| Self::cmp_weight_desc(a.1, b.1));
        pairs.truncate(k);

        let (indices, weights): (Vec<usize>, Vec<f32>) = pairs.into_iter().unzip();
        self.indices = indices;
        self.weights = weights;
    }

    /// Fraction of the `dim` positions that have no stored entry:
    /// `1 - stored_entries / dim`.
    ///
    /// Returns `0.0` when `dim` is zero, where the ratio is undefined,
    /// instead of propagating NaN or infinity to callers.
    pub fn sparsity(&self) -> f32 {
        if self.dim == 0 {
            return 0.0;
        }
        1.0 - (self.indices.len() as f32 / self.dim as f32)
    }

    /// Expands the vector into a dense `Vec<f32>` of length `dim`.
    ///
    /// Positions without a stored entry are `0.0`. Entries whose index is
    /// `>= dim` are ignored; if an index is stored more than once the last
    /// weight wins.
    pub fn to_dense(&self) -> Vec<f32> {
        let mut dense = vec![0.0; self.dim];
        for (idx, weight) in self.entries() {
            if let Some(slot) = dense.get_mut(idx) {
                *slot = weight;
            }
        }
        dense
    }
}
/// Produces `SparseVector`s from text according to a `SpladeConfig`.
pub struct SpladeEncoder {
/// Settings consulted by `encode` and `apply_activation`.
config: SpladeConfig,
}
impl SpladeEncoder {
    /// Creates an encoder that uses `config`.
    ///
    /// # Errors
    ///
    /// The current implementation never fails; the `Result` return type is
    /// kept so callers do not change if construction becomes fallible.
    pub fn new(config: SpladeConfig) -> Result<Self> {
        Ok(Self { config })
    }

    /// Encodes `text` into a sparse vector of dimension `vocab_size`.
    ///
    /// This is a placeholder: `text` is not inspected and the same five mock
    /// terms are produced for every input. The mock terms are then pruned
    /// with `sparsity_threshold` and, when `max_terms` is set, reduced to the
    /// highest-weighted `max_terms` entries.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub fn encode(&self, text: &str) -> Result<SparseVector> {
        // The mock encoder ignores its input; acknowledge that explicitly so
        // the parameter keeps its public name without an unused warning.
        let _ = text;

        let mock_indices = vec![100, 250, 500, 1000, 2000];
        let mock_weights = vec![2.5, 1.8, 1.2, 0.9, 0.5];
        let mut sparse = SparseVector::new(mock_indices, mock_weights, self.config.vocab_size);

        sparse.prune(self.config.sparsity_threshold);
        if let Some(max_terms) = self.config.max_terms {
            sparse.top_k(max_terms);
        }
        Ok(sparse)
    }

    /// Encodes every string in `texts`, returning the vectors in input order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `encode`.
    pub fn encode_batch(&self, texts: &[&str]) -> Result<Vec<SparseVector>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }

    /// Applies the configured activation to each logit.
    ///
    /// * `Log1p`: `ln(1 + x)`
    /// * `Relu`: `max(x, 0)`
    /// * `Log1pRelu`: `ln(1 + max(x, 0))`
    ///
    /// `ln_1p` is used instead of `(1.0 + x).ln()` because in `f32` the sum
    /// `1.0 + x` rounds to exactly `1.0` for small `x`, which would collapse
    /// small activations to zero.
    fn apply_activation(&self, logits: &[f32]) -> Vec<f32> {
        match self.config.activation {
            ActivationType::Log1p => logits.iter().map(|&x| x.ln_1p()).collect(),
            ActivationType::Relu => logits.iter().map(|&x| x.max(0.0)).collect(),
            ActivationType::Log1pRelu => logits.iter().map(|&x| x.max(0.0).ln_1p()).collect(),
        }
    }
}
/// Inverted index over `SparseVector` documents, searched by dot-product score.
pub struct SparseIndex {
/// Term id -> postings, each posting being `(document id, weight of the term in that document)`.
inverted_index: HashMap<usize, Vec<(String, f32)>>,
/// Document id -> the vector that was supplied to `add` for it.
documents: HashMap<String, SparseVector>,
}
impl SparseIndex {
    /// Creates an index containing no documents.
    pub fn new() -> Self {
        Self {
            inverted_index: HashMap::new(),
            documents: HashMap::new(),
        }
    }

    /// Adds `vector` under `doc_id`, replacing any document already stored
    /// under that id.
    ///
    /// Replacing first removes the old document's postings; otherwise the
    /// stale postings would keep contributing to `search` scores and `stats`.
    pub fn add(&mut self, doc_id: String, vector: SparseVector) {
        if let Some(previous) = self.documents.remove(&doc_id) {
            self.remove_postings(&doc_id, &previous);
        }

        for (&term_id, &weight) in vector.indices.iter().zip(vector.weights.iter()) {
            self.inverted_index
                .entry(term_id)
                .or_default()
                .push((doc_id.clone(), weight));
        }
        self.documents.insert(doc_id, vector);
    }

    /// Deletes every posting of `doc_id` for the terms listed in `vector`,
    /// dropping posting lists that become empty.
    fn remove_postings(&mut self, doc_id: &str, vector: &SparseVector) {
        for term_id in &vector.indices {
            let emptied = match self.inverted_index.get_mut(term_id) {
                Some(postings) => {
                    postings.retain(|(id, _)| id.as_str() != doc_id);
                    postings.is_empty()
                }
                None => false,
            };
            if emptied {
                self.inverted_index.remove(term_id);
            }
        }
    }

    /// Orders scores from largest to smallest, placing NaN after every real
    /// number so a NaN score cannot panic the sort or outrank a real match.
    fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    /// Returns up to `k` `(document id, score)` pairs, best score first.
    ///
    /// A document's score is the sum of `query_weight * doc_weight` over the
    /// query terms it contains; documents sharing no term with the query are
    /// not returned. Equal scores are ordered by ascending document id so the
    /// result does not depend on hash-map iteration order, and NaN scores
    /// rank last.
    pub fn search(&self, query: &SparseVector, k: usize) -> Vec<(String, f32)> {
        // Borrow the ids while accumulating; only the final top-k are cloned.
        let mut scores: HashMap<&str, f32> = HashMap::new();
        for (term_id, &query_weight) in query.indices.iter().zip(query.weights.iter()) {
            if let Some(postings) = self.inverted_index.get(term_id) {
                for (doc_id, doc_weight) in postings {
                    *scores.entry(doc_id.as_str()).or_insert(0.0) += query_weight * doc_weight;
                }
            }
        }

        let mut ranked: Vec<(&str, f32)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| Self::cmp_score_desc(a.1, b.1).then_with(|| a.0.cmp(b.0)));
        ranked.truncate(k);
        ranked
            .into_iter()
            .map(|(doc_id, score)| (doc_id.to_owned(), score))
            .collect()
    }

    /// Summarises the index: document count, number of distinct terms, total
    /// number of postings, and the mean `sparsity()` of the stored documents
    /// (`0.0` when the index is empty).
    pub fn stats(&self) -> SparseIndexStats {
        let num_documents = self.documents.len();
        let avg_sparsity = if num_documents == 0 {
            0.0
        } else {
            self.documents.values().map(|v| v.sparsity()).sum::<f32>() / num_documents as f32
        };

        SparseIndexStats {
            num_documents,
            num_unique_terms: self.inverted_index.len(),
            total_postings: self.inverted_index.values().map(Vec::len).sum(),
            avg_sparsity,
        }
    }
}
impl Default for SparseIndex {
fn default() -> Self {
Self::new()
}
}
/// Summary figures returned by `SparseIndex::stats`.
#[derive(Debug, Clone)]
pub struct SparseIndexStats {
/// Number of documents stored in the index.
pub num_documents: usize,
/// Number of distinct term ids that have a posting list.
pub num_unique_terms: usize,
/// Total number of postings across all posting lists.
pub total_postings: usize,
/// Mean of `SparseVector::sparsity` over the stored documents; `0.0` for an empty index.
pub avg_sparsity: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    // The constructor stores its three arguments unchanged.
    #[test]
    fn test_sparse_vector_creation() {
        let expected_indices = vec![10, 20, 30];
        let expected_weights = vec![1.0, 2.0, 3.0];

        let vector = SparseVector::new(expected_indices.clone(), expected_weights.clone(), 100);

        assert_eq!(vector.indices, expected_indices);
        assert_eq!(vector.weights, expected_weights);
        assert_eq!(vector.dim, 100);
    }

    // Only index 2 is shared, so the product is 2.0 * 2.0.
    #[test]
    fn test_sparse_dot_product() {
        let left = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 10);
        let right = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0], 10);

        assert_eq!(left.dot(&right), 4.0);
    }

    // A 3-4-0 vector has L1 norm 7 and L2 norm 5.
    #[test]
    fn test_sparse_norms() {
        let vector = SparseVector::new(vec![0, 1, 2], vec![3.0, 4.0, 0.0], 10);

        assert_eq!(vector.l1_norm(), 7.0);
        assert_eq!(vector.l2_norm(), 5.0);
    }

    // Entries below the threshold are removed; the rest keep their order.
    #[test]
    fn test_sparse_pruning() {
        let mut vector = SparseVector::new(vec![0, 1, 2, 3], vec![2.0, 0.5, 1.5, 0.1], 10);

        vector.prune(1.0);

        assert_eq!(vector.indices, vec![0, 2]);
        assert_eq!(vector.weights, vec![2.0, 1.5]);
    }

    // The three largest weights survive a top-3 truncation.
    #[test]
    fn test_sparse_top_k() {
        let mut vector =
            SparseVector::new(vec![0, 1, 2, 3, 4], vec![1.0, 5.0, 2.0, 4.0, 3.0], 10);

        vector.top_k(3);

        assert_eq!(vector.indices.len(), 3);
        assert_eq!(vector.weights.len(), 3);
        for expected in [5.0, 4.0, 3.0] {
            assert!(vector.weights.contains(&expected));
        }
    }

    // Two stored entries out of one hundred positions.
    #[test]
    fn test_sparse_sparsity() {
        let vector = SparseVector::new(vec![0, 1], vec![1.0, 2.0], 100);

        assert_eq!(vector.sparsity(), 0.98);
    }

    // Missing positions are filled with zeros.
    #[test]
    fn test_sparse_to_dense() {
        let vector = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 5);

        assert_eq!(vector.to_dense(), vec![1.0, 0.0, 2.0, 0.0, 3.0]);
    }

    // Construction with the default configuration succeeds.
    #[test]
    fn test_splade_encoder_creation() {
        let created = SpladeEncoder::new(SpladeConfig::default());

        assert!(created.is_ok());
    }

    // Encoding yields a non-empty vector with matching index/weight lengths.
    #[test]
    fn test_splade_encode() {
        let encoder = SpladeEncoder::new(SpladeConfig::default()).unwrap();

        let encoded = encoder.encode("test query").unwrap();

        assert!(!encoded.indices.is_empty());
        assert_eq!(encoded.indices.len(), encoded.weights.len());
    }

    // doc1 scores 2.0 + 1.0 and doc2 scores 1.5 + 1.0, so doc1 ranks first.
    #[test]
    fn test_sparse_index_add_and_search() {
        let mut index = SparseIndex::new();
        index.add(
            "doc1".to_string(),
            SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 1.0], 100),
        );
        index.add(
            "doc2".to_string(),
            SparseVector::new(vec![2, 3, 4], vec![1.5, 1.0, 2.0], 100),
        );

        let query = SparseVector::new(vec![2, 3], vec![1.0, 1.0], 100);
        let hits = index.search(&query, 2);

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, "doc1");
    }

    // Statistics reflect the two documents that were added.
    #[test]
    fn test_sparse_index_stats() {
        let mut index = SparseIndex::new();
        index.add(
            "doc1".to_string(),
            SparseVector::new(vec![1, 2], vec![1.0, 2.0], 100),
        );
        index.add(
            "doc2".to_string(),
            SparseVector::new(vec![2, 3], vec![1.5, 1.0], 100),
        );

        let summary = index.stats();

        assert_eq!(summary.num_documents, 2);
        assert!(summary.num_unique_terms >= 2);
        assert!(summary.avg_sparsity > 0.9);
    }

    // ln(1 + 0) is 0 and ln(1 + 1) is roughly 0.693.
    #[test]
    fn test_activation_log1p() {
        let encoder = SpladeEncoder::new(SpladeConfig {
            activation: ActivationType::Log1p,
            ..Default::default()
        })
        .unwrap();

        let activated = encoder.apply_activation(&[0.0, 1.0, 2.0]);

        assert!((activated[0] - 0.0).abs() < 0.001);
        assert!((activated[1] - 0.693).abs() < 0.01);
    }
}