#![allow(clippy::doc_markdown)]
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors surfaced by sparse-embedding construction and encoding.
///
/// `#[non_exhaustive]` so new variants can be added without a breaking
/// change; downstream matches must keep a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum SparseError {
    /// Network-level failure; the payload is a human-readable message.
    #[error("network error: {0}")]
    Network(String),
    /// Invalid configuration or constructor arguments (e.g. mismatched
    /// `indices`/`values` lengths, or indices not strictly ascending).
    #[error("config error: {0}")]
    Config(String),
    /// Failure while running model inference.
    #[error("inference error: {0}")]
    Inference(String),
    /// Input text was empty or whitespace-only.
    #[error("empty input")]
    EmptyInput,
}
/// A sparse vector: parallel `indices`/`values` arrays plus the id of the
/// vocabulary the indices refer to.
///
/// Invariants upheld by the constructors (`new`, `from_unsorted`):
/// `indices` and `values` have equal length and `indices` is strictly
/// ascending. NOTE(review): the fields are `pub`, so direct struct-literal
/// construction bypasses these checks.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseEmbed {
    // Strictly ascending vocabulary indices of the stored entries.
    pub indices: Vec<u32>,
    // Weight for each entry in `indices` (parallel array, same length).
    pub values: Vec<f32>,
    // Vocabulary identifier; `dot` only combines embeddings whose ids match.
    pub vocab_id: String,
}
impl SparseEmbed {
    /// Builds an embedding from pre-validated parallel arrays.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::Config`] when the arrays differ in length or
    /// when `indices` is not strictly ascending.
    pub fn new(
        indices: Vec<u32>,
        values: Vec<f32>,
        vocab_id: impl Into<String>,
    ) -> Result<Self, SparseError> {
        if indices.len() != values.len() {
            return Err(SparseError::Config(format!(
                "indices.len() {} != values.len() {}",
                indices.len(),
                values.len()
            )));
        }
        // Adjacent-pair scan: each index must be strictly less than its successor.
        for (prev, next) in indices.iter().zip(indices.iter().skip(1)) {
            if prev >= next {
                return Err(SparseError::Config(format!(
                    "indices must be strictly ascending; saw {} then {}",
                    prev, next
                )));
            }
        }
        let vocab_id = vocab_id.into();
        Ok(Self {
            indices,
            values,
            vocab_id,
        })
    }

    /// Builds an embedding from arbitrary `(index, weight)` pairs.
    ///
    /// Duplicate indices are max-pooled (largest weight wins), entries whose
    /// pooled weight is not strictly positive are dropped, and the result is
    /// ordered by index.
    pub fn from_unsorted(
        pairs: impl IntoIterator<Item = (u32, f32)>,
        vocab_id: impl Into<String>,
    ) -> Self {
        use std::collections::BTreeMap;
        // A BTreeMap keeps keys ordered, so ascending indices come for free.
        let mut pooled: BTreeMap<u32, f32> = BTreeMap::new();
        for (index, weight) in pairs {
            let slot = pooled.entry(index).or_insert(f32::NEG_INFINITY);
            // f32::max ignores a NaN argument, matching the original
            // "replace only when strictly greater" pooling behavior.
            *slot = slot.max(weight);
        }
        let (indices, values): (Vec<_>, Vec<_>) = pooled
            .into_iter()
            .filter(|&(_, weight)| weight > 0.0)
            .unzip();
        Self {
            indices,
            values,
            vocab_id: vocab_id.into(),
        }
    }

    /// Number of stored entries (the "number of non-zeros").
    #[must_use]
    pub const fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Sparse dot product computed as a linear merge over the two strictly
    /// ascending index lists.
    ///
    /// Returns `None` when the embeddings come from different vocabularies,
    /// since their indices are not comparable.
    #[must_use]
    pub fn dot(&self, other: &Self) -> Option<f32> {
        if self.vocab_id != other.vocab_id {
            return None;
        }
        let mut lhs = self.indices.iter().zip(&self.values).peekable();
        let mut rhs = other.indices.iter().zip(&other.values).peekable();
        let mut acc = 0.0f32;
        while let (Some(&(li, lv)), Some(&(ri, rv))) = (lhs.peek(), rhs.peek()) {
            use std::cmp::Ordering;
            match li.cmp(ri) {
                Ordering::Less => {
                    let _ = lhs.next();
                }
                Ordering::Greater => {
                    let _ = rhs.next();
                }
                Ordering::Equal => {
                    acc += lv * rv;
                    let _ = lhs.next();
                    let _ = rhs.next();
                }
            }
        }
        Some(acc)
    }
}
/// A text-to-sparse-vector encoder.
///
/// Implementations must be shareable across threads (`Send + Sync`) and
/// printable for diagnostics (`Debug`).
pub trait SparseEncoder: Send + Sync + Debug {
    /// Human-readable identifier of the underlying model.
    fn model(&self) -> &str;
    /// Id of the vocabulary this encoder emits indices into; embeddings are
    /// only dot-product-compatible when their `vocab_id`s match.
    fn vocab_id(&self) -> &str;
    /// Encodes `text` into a sparse embedding.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseError`] when encoding fails (e.g. empty input).
    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError>;
    /// Encodes a query string; defaults to the same behavior as `encode`.
    /// Override for models that encode queries and documents differently.
    fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
        self.encode(text)
    }
}
/// 32-bit FNV offset basis — the seed of the mock token hash.
const FNV_OFFSET_BASIS_32: u32 = 2_166_136_261;
/// 32-bit FNV prime multiplier used in the hash fold.
const FNV_PRIME_32: u32 = 16_777_619;
/// Size of the mock vocabulary; token hashes are reduced modulo this value.
const MOCK_VOCAB_SIZE: u32 = 1024;
/// Deterministic, dependency-free encoder for tests and local development:
/// hashes whitespace-separated tokens into a fixed-size vocabulary and
/// weights each token inversely by its byte length.
#[derive(Debug, Clone)]
pub struct MockSparseEncoder {
    // Vocabulary id stamped onto every embedding this encoder produces.
    vocab_id: String,
}
impl Default for MockSparseEncoder {
    /// Creates a mock encoder tagged with the default "mock:1024" vocabulary id.
    fn default() -> Self {
        let vocab_id = String::from("mock:1024");
        Self { vocab_id }
    }
}
impl SparseEncoder for MockSparseEncoder {
    fn model(&self) -> &str {
        "mock:len-inverse"
    }

    fn vocab_id(&self) -> &str {
        self.vocab_id.as_str()
    }

    /// Hashes each whitespace token with an FNV-style fold (multiply-then-add),
    /// reduces it into the mock vocabulary, and assigns weight `1 / (1 + len)`.
    /// Duplicate bucket collisions are max-pooled by `from_unsorted`.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::EmptyInput`] when `text` is empty or
    /// whitespace-only.
    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
        if text.trim().is_empty() {
            return Err(SparseError::EmptyInput);
        }
        let mut pairs = Vec::new();
        for token in text.split_whitespace() {
            let mut hash = FNV_OFFSET_BASIS_32;
            for byte in token.bytes() {
                hash = hash.wrapping_mul(FNV_PRIME_32).wrapping_add(u32::from(byte));
            }
            // Longer tokens get smaller weights; always strictly positive, so
            // no token is dropped by the positivity filter downstream.
            let weight = 1.0f32 / (1.0 + token.len() as f32);
            pairs.push((hash % MOCK_VOCAB_SIZE, weight));
        }
        Ok(SparseEmbed::from_unsorted(pairs, &self.vocab_id))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_embed_rejects_unsorted_indices() {
        let err = SparseEmbed::new(vec![5, 3], vec![0.5, 0.5], "v0").unwrap_err();
        assert!(matches!(err, SparseError::Config(_)));
    }

    #[test]
    fn sparse_embed_rejects_length_mismatch() {
        let err = SparseEmbed::new(vec![1, 2], vec![0.5], "v0").unwrap_err();
        assert!(matches!(err, SparseError::Config(_)));
    }

    #[test]
    fn from_unsorted_sorts_and_max_pools() {
        let embed = SparseEmbed::from_unsorted([(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.2)], "v0");
        assert_eq!(embed.indices, vec![1, 3, 5]);
        // Index 5 appears twice; the larger weight (0.3) must survive.
        let pooled = embed.values[2];
        assert!(
            (pooled - 0.3).abs() < 1e-6,
            "max-pool should keep 0.3 for index 5"
        );
    }

    #[test]
    fn from_unsorted_drops_zero_weights() {
        // Both zero and negative weights must be filtered out.
        let embed = SparseEmbed::from_unsorted([(1, 0.0), (2, 0.5), (3, -0.1)], "v0");
        assert_eq!(embed.indices, vec![2]);
    }

    #[test]
    fn dot_product_on_disjoint_is_zero() {
        let lhs = SparseEmbed::new(vec![1, 2], vec![1.0, 1.0], "v").unwrap();
        let rhs = SparseEmbed::new(vec![3, 4], vec![1.0, 1.0], "v").unwrap();
        assert_eq!(lhs.dot(&rhs), Some(0.0));
    }

    #[test]
    fn dot_product_on_overlap() {
        let lhs = SparseEmbed::new(vec![1, 2, 5], vec![0.5, 0.5, 0.2], "v").unwrap();
        let rhs = SparseEmbed::new(vec![2, 5, 9], vec![0.4, 0.3, 0.1], "v").unwrap();
        // Overlap at indices 2 and 5: 0.5*0.4 + 0.2*0.3 = 0.26.
        let d = lhs.dot(&rhs).unwrap();
        assert!((d - 0.26).abs() < 1e-6, "got {d}");
    }

    #[test]
    fn dot_product_different_vocabs_is_none() {
        let lhs = SparseEmbed::new(vec![1], vec![1.0], "v0").unwrap();
        let rhs = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert_eq!(lhs.dot(&rhs), None);
    }

    #[test]
    fn mock_encoder_is_deterministic() {
        let encoder = MockSparseEncoder::default();
        let first = encoder.encode("hello world").unwrap();
        let second = encoder.encode("hello world").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn mock_encoder_empty_input_errors() {
        let encoder = MockSparseEncoder::default();
        let err = encoder.encode(" ").unwrap_err();
        assert!(matches!(err, SparseError::EmptyInput));
    }

    #[test]
    fn mock_encoder_vocab_id_carries_through() {
        let encoder = MockSparseEncoder::default();
        let embedding = encoder.encode("hello").unwrap();
        assert_eq!(embedding.vocab_id, encoder.vocab_id());
    }
}