use serde::{Deserialize, Serialize};
/// Index format used by a sparse-vector field.
///
/// `MaxScore` is the default variant. `SparseVectorConfig::to_byte` records
/// `Bmp` as bit `0x08` of the packed config byte. Serialized by variant name.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
pub enum SparseFormat {
    /// Default format.
    #[default]
    MaxScore,
    /// Alternative format, flagged by bit `0x08` in the packed config byte.
    Bmp,
}
impl SparseFormat {
fn is_default(&self) -> bool {
*self == Self::MaxScore
}
}
/// Width used to store indices: `U16` (2 bytes, max 65535) or `U32`
/// (4 bytes, the default).
///
/// The explicit discriminants are what `SparseVectorConfig::to_byte` packs
/// into bits 4–5 and what `IndexSize::from_u8` decodes; do not renumber.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[repr(u8)]
pub enum IndexSize {
    /// 16-bit indices.
    U16 = 0,
    /// 32-bit indices (default).
    #[default]
    U32 = 1,
}
impl IndexSize {
pub fn bytes(&self) -> usize {
match self {
IndexSize::U16 => 2,
IndexSize::U32 => 4,
}
}
pub fn max_value(&self) -> u32 {
match self {
IndexSize::U16 => u16::MAX as u32,
IndexSize::U32 => u32::MAX,
}
}
pub(crate) fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(IndexSize::U16),
1 => Some(IndexSize::U32),
_ => None,
}
}
}
/// Storage precision for weights. `Float32` is the default.
///
/// Sizes per weight are 4, 2, 1 and 0.5 bytes respectively. The explicit
/// discriminants are what `SparseVectorConfig::to_byte` packs into the low
/// bits of the config byte and what `WeightQuantization::from_u8` decodes;
/// do not renumber.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[repr(u8)]
pub enum WeightQuantization {
    /// 32-bit float weights (default).
    #[default]
    Float32 = 0,
    /// 16-bit float weights.
    Float16 = 1,
    /// 8-bit unsigned quantized weights.
    UInt8 = 2,
    /// 4-bit unsigned quantized weights.
    UInt4 = 3,
}
impl WeightQuantization {
pub fn bytes_per_weight(&self) -> f32 {
match self {
WeightQuantization::Float32 => 4.0,
WeightQuantization::Float16 => 2.0,
WeightQuantization::UInt8 => 1.0,
WeightQuantization::UInt4 => 0.5,
}
}
pub(crate) fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(WeightQuantization::Float32),
1 => Some(WeightQuantization::Float16),
2 => Some(WeightQuantization::UInt8),
3 => Some(WeightQuantization::UInt4),
_ => None,
}
}
}
/// How query-term weights are assigned. `One` is the default.
///
/// Serialized in snake_case: `"one"`, `"idf"`, `"idf_file"`.
/// NOTE(review): by name, `Idf` / `IdfFile` presumably weight terms by inverse
/// document frequency (computed vs. loaded from a file). Confirm against the
/// query code.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueryWeighting {
    /// Default weighting.
    #[default]
    One,
    /// IDF-based weighting.
    Idf,
    /// IDF-based weighting sourced from a file.
    IdfFile,
}
/// Query-time settings for a sparse-vector field.
///
/// Every field has a serde default. With the current helper values, an empty
/// object deserializes to the same values as `SparseQueryConfig::default()`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseQueryConfig {
    /// Optional tokenizer name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokenizer: Option<String>,
    /// Term weighting scheme; defaults to `QueryWeighting::One`.
    #[serde(default)]
    pub weighting: QueryWeighting,
    /// Heap factor; defaults to `default_heap_factor()` (1.0).
    /// NOTE(review): presumably an approximation knob for top-k pruning — confirm.
    #[serde(default = "default_heap_factor")]
    pub heap_factor: f32,
    /// Query weight threshold; defaults to 0.0.
    #[serde(default)]
    pub weight_threshold: f32,
    /// Optional cap on the number of query dimensions; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_query_dims: Option<usize>,
    /// Optional query pruning fraction; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pruning: Option<f32>,
    /// Minimum number of query dimensions; defaults to `default_min_terms()` (4).
    #[serde(default = "default_min_terms")]
    pub min_query_dims: usize,
    /// Superblock limit; defaults to 0 (what 0 means is decided by the query
    /// code, not visible here).
    #[serde(default)]
    pub max_superblocks: usize,
}
/// Serde default for `SparseQueryConfig::heap_factor`.
fn default_heap_factor() -> f32 { 1.0 }
impl Default for SparseQueryConfig {
fn default() -> Self {
Self {
tokenizer: None,
weighting: QueryWeighting::One,
heap_factor: 1.0,
weight_threshold: 0.0,
max_query_dims: None,
pruning: None,
min_query_dims: 4,
max_superblocks: 0,
}
}
}
/// Index-time configuration for a sparse-vector field.
///
/// `index_size` and `weight_quantization` have no serde default and must be
/// present when deserializing; every other field falls back to a default.
/// Only `format`, `index_size` and `weight_quantization` are captured by
/// `to_byte` / `from_byte`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseVectorConfig {
    /// Index format; omitted from serialized output when it is `MaxScore`.
    #[serde(default, skip_serializing_if = "SparseFormat::is_default")]
    pub format: SparseFormat,
    /// Width of stored indices (required in serialized form).
    pub index_size: IndexSize,
    /// Storage precision of weights (required in serialized form).
    pub weight_quantization: WeightQuantization,
    /// Index-time weight threshold; defaults to 0.0.
    #[serde(default)]
    pub weight_threshold: f32,
    /// Block size; serde default 128. `with_block_size` rounds up to a power of two.
    #[serde(default = "default_block_size")]
    pub block_size: usize,
    /// BMP block size; serde default 64.
    #[serde(default = "default_bmp_block_size")]
    pub bmp_block_size: u32,
    /// BMP grid byte budget; serde default 0 (meaning decided elsewhere).
    #[serde(default = "default_max_bmp_grid_bytes")]
    pub max_bmp_grid_bytes: u64,
    /// BMP superblock size; serde default 64.
    #[serde(default = "default_bmp_superblock_size")]
    pub bmp_superblock_size: u32,
    /// Optional index-time pruning fraction; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pruning: Option<f32>,
    /// Optional query-time settings; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub query_config: Option<SparseQueryConfig>,
    /// Optional dimensionality; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dims: Option<u32>,
    /// Optional maximum weight; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_weight: Option<f32>,
    /// Minimum term count; serde default 4.
    #[serde(default = "default_min_terms")]
    pub min_terms: usize,
}
/// Serde default for `SparseVectorConfig::block_size`.
fn default_block_size() -> usize { 128 }
/// Serde default for `SparseVectorConfig::bmp_block_size`.
fn default_bmp_block_size() -> u32 { 64 }
/// Serde default for `SparseVectorConfig::max_bmp_grid_bytes`.
///
/// NOTE(review): 0 presumably means "no explicit budget" — confirm with the BMP builder.
fn default_max_bmp_grid_bytes() -> u64 {
    0
}
/// Serde default for `SparseVectorConfig::bmp_superblock_size`.
fn default_bmp_superblock_size() -> u32 { 64 }
/// Serde default shared by `SparseVectorConfig::min_terms` and
/// `SparseQueryConfig::min_query_dims`.
fn default_min_terms() -> usize { 4 }
impl Default for SparseVectorConfig {
fn default() -> Self {
Self {
format: SparseFormat::MaxScore,
index_size: IndexSize::U32,
weight_quantization: WeightQuantization::Float32,
weight_threshold: 0.0,
block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: None,
query_config: None,
dims: None,
max_weight: None,
min_terms: 4,
}
}
}
impl SparseVectorConfig {
pub fn splade() -> Self {
Self {
format: SparseFormat::MaxScore,
index_size: IndexSize::U16,
weight_quantization: WeightQuantization::UInt8,
weight_threshold: 0.01, block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: Some(0.1), query_config: Some(SparseQueryConfig {
tokenizer: None,
weighting: QueryWeighting::One,
heap_factor: 0.8, weight_threshold: 0.01, max_query_dims: Some(20), pruning: Some(0.1), min_query_dims: 4,
max_superblocks: 0,
}),
dims: None,
max_weight: None,
min_terms: 4,
}
}
pub fn splade_bmp() -> Self {
Self {
format: SparseFormat::Bmp,
index_size: IndexSize::U16,
weight_quantization: WeightQuantization::UInt8,
weight_threshold: 0.01,
block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: Some(0.1),
query_config: Some(SparseQueryConfig {
tokenizer: None,
weighting: QueryWeighting::One,
heap_factor: 0.8,
weight_threshold: 0.01,
max_query_dims: Some(20),
pruning: Some(0.1),
min_query_dims: 4,
max_superblocks: 0,
}),
dims: Some(105879),
max_weight: Some(5.0),
min_terms: 4,
}
}
pub fn compact() -> Self {
Self {
format: SparseFormat::MaxScore,
index_size: IndexSize::U16,
weight_quantization: WeightQuantization::UInt4,
weight_threshold: 0.02, block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: Some(0.15), query_config: Some(SparseQueryConfig {
tokenizer: None,
weighting: QueryWeighting::One,
heap_factor: 0.7, weight_threshold: 0.02, max_query_dims: Some(15), pruning: Some(0.15), min_query_dims: 4,
max_superblocks: 0,
}),
dims: None,
max_weight: None,
min_terms: 4,
}
}
pub fn full_precision() -> Self {
Self {
format: SparseFormat::MaxScore,
index_size: IndexSize::U32,
weight_quantization: WeightQuantization::Float32,
weight_threshold: 0.0,
block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: None,
query_config: None,
dims: None,
max_weight: None,
min_terms: 4,
}
}
pub fn conservative() -> Self {
Self {
format: SparseFormat::MaxScore,
index_size: IndexSize::U32,
weight_quantization: WeightQuantization::Float16,
weight_threshold: 0.005, block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: None, query_config: Some(SparseQueryConfig {
tokenizer: None,
weighting: QueryWeighting::One,
heap_factor: 0.9, weight_threshold: 0.005, max_query_dims: Some(50), pruning: None, min_query_dims: 4,
max_superblocks: 0,
}),
dims: None,
max_weight: None,
min_terms: 4,
}
}
pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
self.weight_threshold = threshold;
self
}
pub fn with_pruning(mut self, fraction: f32) -> Self {
self.pruning = Some(fraction.clamp(0.0, 1.0));
self
}
pub fn bytes_per_entry(&self) -> f32 {
self.index_size.bytes() as f32 + self.weight_quantization.bytes_per_weight()
}
pub fn to_byte(&self) -> u8 {
let format_bit = if self.format == SparseFormat::Bmp {
0x08
} else {
0
};
((self.index_size as u8) << 4) | format_bit | (self.weight_quantization as u8)
}
pub fn from_byte(b: u8) -> Option<Self> {
let index_size = IndexSize::from_u8((b >> 4) & 0x03)?;
let format = if b & 0x08 != 0 {
SparseFormat::Bmp
} else {
SparseFormat::MaxScore
};
let weight_quantization = WeightQuantization::from_u8(b & 0x07)?;
Some(Self {
format,
index_size,
weight_quantization,
weight_threshold: 0.0,
block_size: 128,
bmp_block_size: 64,
max_bmp_grid_bytes: 0,
bmp_superblock_size: 64,
pruning: None,
query_config: None,
dims: None,
max_weight: None,
min_terms: 4,
})
}
pub fn with_block_size(mut self, size: usize) -> Self {
self.block_size = size.next_power_of_two();
self
}
pub fn with_query_config(mut self, config: SparseQueryConfig) -> Self {
self.query_config = Some(config);
self
}
}
/// One stored coordinate of a [`SparseVector`]: a dimension id and its weight.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SparseEntry {
    /// Dimension (term) id.
    pub dim_id: u32,
    /// Weight at that dimension.
    pub weight: f32,
}
/// Sparse vector stored as a list of [`SparseEntry`] values.
///
/// Entry order depends on how the vector was built:
/// - `from_entries` sorts by `dim_id`.
/// - `push` debug-asserts ascending `dim_id`.
/// - `dot` relies on that sorted order.
/// - `sort_by_weight_desc` and `From<Vec<(u32, f32)>>` do not keep it.
#[derive(Clone, Debug, Default)]
pub struct SparseVector {
    pub(super) entries: Vec<SparseEntry>,
}
impl SparseVector {
    /// Creates an empty sparse vector.
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Creates an empty sparse vector with space reserved for `capacity` entries.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            entries: Vec::with_capacity(capacity),
        }
    }

    /// Builds a vector from parallel `dim_ids` / `weights` slices, sorted by `dim_id`.
    ///
    /// The sort is stable, so duplicate dimension ids are kept (not merged)
    /// in their input order.
    ///
    /// # Panics
    ///
    /// Panics if `dim_ids` and `weights` have different lengths.
    pub fn from_entries(dim_ids: &[u32], weights: &[f32]) -> Self {
        assert_eq!(
            dim_ids.len(),
            weights.len(),
            "dim_ids and weights must have the same length"
        );
        let mut entries: Vec<SparseEntry> = dim_ids
            .iter()
            .zip(weights)
            .map(|(&dim_id, &weight)| SparseEntry { dim_id, weight })
            .collect();
        entries.sort_by_key(|e| e.dim_id);
        Self { entries }
    }

    /// Appends an entry.
    ///
    /// Callers must push in strictly increasing `dim_id` order; this is only
    /// checked in debug builds.
    pub fn push(&mut self, dim_id: u32, weight: f32) {
        debug_assert!(
            !matches!(self.entries.last(), Some(last) if last.dim_id >= dim_id),
            "Entries must be added in sorted order by dim_id"
        );
        self.entries.push(SparseEntry { dim_id, weight });
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Iterates over the entries in stored order.
    pub fn iter(&self) -> impl Iterator<Item = &SparseEntry> {
        self.entries.iter()
    }

    /// Sorts entries by ascending `dim_id` (stable), restoring the order `dot` expects.
    pub fn sort_by_dim(&mut self) {
        self.entries.sort_by_key(|e| e.dim_id);
    }

    /// Sorts entries by descending weight (stable).
    ///
    /// Uses `f32::total_cmp`, so the comparator is a total order even with
    /// NaN weights: positive NaN sorts first, negative NaN last, and `0.0`
    /// before `-0.0`. Afterwards the entries are no longer ordered by
    /// `dim_id`; call `sort_by_dim` before `dot`.
    pub fn sort_by_weight_desc(&mut self) {
        self.entries.sort_by(Self::cmp_weight_desc);
    }

    /// Returns copies of the `k` highest-weight entries, highest first.
    ///
    /// Uses the same ordering as `sort_by_weight_desc`. Returns every entry
    /// when `k >= len()`; `self` is not modified.
    pub fn top_k(&self, k: usize) -> Vec<SparseEntry> {
        let mut sorted = self.entries.clone();
        sorted.sort_by(Self::cmp_weight_desc);
        sorted.truncate(k);
        sorted
    }

    /// Descending-weight comparator shared by `sort_by_weight_desc` and `top_k`.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is
    /// present. `slice::sort_by` may panic or mis-sort with such a
    /// comparator, so `total_cmp` is used instead.
    fn cmp_weight_desc(a: &SparseEntry, b: &SparseEntry) -> std::cmp::Ordering {
        b.weight.total_cmp(&a.weight)
    }

    /// Dot product with `other`, computed as a merge join over both entry lists.
    ///
    /// Both vectors must be sorted by ascending `dim_id`. On unsorted input,
    /// matching dimensions can be skipped and the result is silently wrong.
    /// Repeated dimension ids are paired one-to-one in order.
    pub fn dot(&self, other: &SparseVector) -> f32 {
        use std::cmp::Ordering;

        let mut result = 0.0f32;
        let mut i = 0;
        let mut j = 0;
        while i < self.entries.len() && j < other.entries.len() {
            let a = &self.entries[i];
            let b = &other.entries[j];
            match a.dim_id.cmp(&b.dim_id) {
                Ordering::Less => i += 1,
                Ordering::Greater => j += 1,
                Ordering::Equal => {
                    result += a.weight * b.weight;
                    i += 1;
                    j += 1;
                }
            }
        }
        result
    }

    /// Sum of squared weights.
    pub fn norm_squared(&self) -> f32 {
        self.entries.iter().map(|e| e.weight * e.weight).sum()
    }

    /// Euclidean (L2) norm: the square root of `norm_squared`.
    pub fn norm(&self) -> f32 {
        self.norm_squared().sqrt()
    }

    /// Returns a new vector with the entries whose absolute weight is at
    /// least `min_weight`, in their existing order.
    ///
    /// NaN weights never pass the comparison and are dropped.
    pub fn filter_by_weight(&self, min_weight: f32) -> Self {
        let entries: Vec<SparseEntry> = self
            .entries
            .iter()
            .filter(|e| e.weight.abs() >= min_weight)
            .copied()
            .collect();
        Self { entries }
    }
}
impl From<Vec<(u32, f32)>> for SparseVector {
    /// Wraps `(dim_id, weight)` pairs as entries, preserving the input order.
    ///
    /// No sorting is done. NOTE(review): `dot` assumes entries are sorted by
    /// `dim_id`, so callers presumably pass pre-sorted pairs or call
    /// `sort_by_dim` afterwards. Confirm at call sites.
    fn from(pairs: Vec<(u32, f32)>) -> Self {
        let mut entries = Vec::with_capacity(pairs.len());
        for (dim_id, weight) in pairs {
            entries.push(SparseEntry { dim_id, weight });
        }
        Self { entries }
    }
}
impl From<SparseVector> for Vec<(u32, f32)> {
fn from(vec: SparseVector) -> Self {
vec.entries
.into_iter()
.map(|e| (e.dim_id, e.weight))
.collect()
}
}