use serde::{Deserialize, Serialize};
/// Errors returned when validating FDE parameters or encoding token sets.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum FdeError {
/// A token's length (`got`) differs from the configured `input_dim` (`expected`).
#[error("muvera: token dimension {got} != configured input_dim {expected}")]
DimensionMismatch { got: usize, expected: usize },
/// The parameter set was rejected; the payload says which check failed.
#[error("muvera: invalid params: {0}")]
InvalidParams(String),
}
/// Default base seed for the encoder's random matrices.
///
/// This is the same 64-bit constant that `SplitMix64::next_u64` in this file
/// uses as its state increment.
pub const DEFAULT_FDE_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
/// Largest `k_sim` accepted by `FdeParams::validate` (so at most `2^16` buckets).
const MAX_K_SIM: u32 = 16;
/// Largest total output dimension (`FdeParams::fde_dim`) accepted by `FdeParams::validate`.
const MAX_FDE_DIM: usize = 200_000;
/// Configuration of an FDE encoder.
///
/// Two encoders built from equal parameters produce identical output, because
/// all random matrices are derived from `seed`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FdeParams {
/// Number of random hyperplanes per repetition; each repetition has `2^k_sim` buckets.
pub k_sim: u32,
/// Number of independent repetitions, each with its own random matrices.
pub reps: u32,
/// Per-bucket projected dimension; `0` disables projection, so each bucket keeps `input_dim` values.
pub d_proj: u32,
/// Length every input token must have.
pub input_dim: u32,
/// Base seed from which each repetition's RNG seed is derived.
pub seed: u64,
}
impl FdeParams {
    /// Number of values stored per bucket: `d_proj`, or `input_dim` when
    /// projection is disabled (`d_proj == 0`).
    #[inline]
    pub fn proj_dim(&self) -> usize {
        match self.d_proj {
            0 => self.input_dim as usize,
            projected => projected as usize,
        }
    }

    /// Number of buckets per repetition, `2^k_sim`.
    ///
    /// Saturates at `usize::MAX` when `k_sim` is too large to shift by, so the
    /// accessor never panics on parameters that have not been validated yet.
    #[inline]
    pub fn buckets(&self) -> usize {
        1usize.checked_shl(self.k_sim).unwrap_or(usize::MAX)
    }

    /// Total length of an encoded vector: `reps * buckets() * proj_dim()`.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing. Without
    /// this, a wrapped product could land back inside the accepted range in
    /// release builds and slip past [`FdeParams::validate`].
    #[inline]
    pub fn fde_dim(&self) -> usize {
        (self.reps as usize)
            .saturating_mul(self.buckets())
            .saturating_mul(self.proj_dim())
    }

    /// Checks that the parameters describe a usable, bounded encoding.
    ///
    /// # Errors
    ///
    /// Returns [`FdeError::InvalidParams`] when `k_sim` is outside
    /// `1..=MAX_K_SIM`, when `reps` or `input_dim` is zero, or when the
    /// resulting `fde_dim()` is zero or exceeds `MAX_FDE_DIM` (which includes
    /// any combination whose product would overflow `usize`).
    pub fn validate(&self) -> Result<(), FdeError> {
        if self.k_sim == 0 || self.k_sim > MAX_K_SIM {
            return Err(FdeError::InvalidParams(format!(
                "k_sim must be in 1..={MAX_K_SIM}, got {}",
                self.k_sim
            )));
        }
        if self.reps == 0 {
            return Err(FdeError::InvalidParams("reps must be >= 1".to_string()));
        }
        if self.input_dim == 0 {
            return Err(FdeError::InvalidParams(
                "input_dim must be >= 1".to_string(),
            ));
        }
        // `fde_dim` saturates, so an overflowing product shows up here as
        // `usize::MAX` and is rejected by the upper bound.
        let dim = self.fde_dim();
        if dim == 0 || dim > MAX_FDE_DIM {
            return Err(FdeError::InvalidParams(format!(
                "fde_dim {dim} out of range (1..={MAX_FDE_DIM}); reduce k_sim/reps/d_proj"
            )));
        }
        Ok(())
    }
}
/// Small deterministic pseudo-random generator used to derive the encoder's
/// random matrices from a seed.
struct SplitMix64 {
    /// Current generator state; advanced by a fixed increment on every draw.
    state: u64,
}

impl SplitMix64 {
    /// Creates a generator whose initial state is `seed`.
    #[inline]
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next mixed 64-bit value.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;
        const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
        const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

        self.state = self.state.wrapping_add(INCREMENT);
        let stage1 = (self.state ^ (self.state >> 30)).wrapping_mul(MULTIPLIER_A);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(MULTIPLIER_B);
        stage2 ^ (stage2 >> 31)
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of one draw.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }

    /// Returns one normally distributed sample via the Box-Muller cosine
    /// branch, consuming two uniform draws.
    #[inline]
    fn next_gaussian(&mut self) -> f32 {
        // Clamp away from zero so the logarithm stays finite.
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        (radius * angle.cos()) as f32
    }
}
/// Derives the RNG seed for repetition `rep` from the encoder's base seed.
///
/// The repetition index is spread by a fixed odd multiplier, added to `base`
/// with wrapping arithmetic, and then run through two xor-shift/multiply
/// rounds and a final xor-shift.
#[inline]
fn rep_seed(base: u64, rep: u32) -> u64 {
    const REP_STRIDE: u64 = 0xD1B5_4A32_D192_ED03;
    const MULTIPLIER_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MULTIPLIER_B: u64 = 0x94D0_49BB_1331_11EB;

    let offset = u64::from(rep).wrapping_mul(REP_STRIDE);
    let mixed = base.wrapping_add(offset);
    let mixed = (mixed ^ (mixed >> 30)).wrapping_mul(MULTIPLIER_A);
    let mixed = (mixed ^ (mixed >> 27)).wrapping_mul(MULTIPLIER_B);
    mixed ^ (mixed >> 31)
}
/// Random matrices belonging to one repetition of the encoder.
struct RepMatrices {
    /// Row-major `k_sim x input_dim` matrix of Gaussian hyperplane normals.
    hyperplanes: Vec<f32>,
    /// Row-major `d_proj x input_dim` matrix of `+/- 1/sqrt(d_proj)` entries,
    /// or `None` when projection is disabled.
    projection: Option<Vec<f32>>,
}

impl RepMatrices {
    /// Generates the matrices for repetition `rep`.
    ///
    /// A single RNG seeded from `(params.seed, rep)` is used; the hyperplanes
    /// are drawn first and the projection signs afterwards, so the draw order
    /// is part of the output's determinism.
    fn build(params: &FdeParams, rep: u32) -> Self {
        let input_dim = params.input_dim as usize;
        let mut rng = SplitMix64::new(rep_seed(params.seed, rep));

        let plane_len = params.k_sim as usize * input_dim;
        let mut hyperplanes = Vec::with_capacity(plane_len);
        for _ in 0..plane_len {
            hyperplanes.push(rng.next_gaussian());
        }

        let projection = match params.d_proj as usize {
            0 => None,
            rows => {
                let magnitude = 1.0f32 / (rows as f32).sqrt();
                let entries = rows * input_dim;
                let mut signs = Vec::with_capacity(entries);
                for _ in 0..entries {
                    // The lowest bit of each draw picks the sign.
                    let negative = rng.next_u64() & 1 != 0;
                    signs.push(if negative { -magnitude } else { magnitude });
                }
                Some(signs)
            }
        };

        RepMatrices {
            hyperplanes,
            projection,
        }
    }

    /// Returns the bucket index of `token`: bit `h` is set when the dot
    /// product with hyperplane `h` is strictly positive.
    #[inline]
    fn bucket_of(&self, token: &[f32], k_sim: u32, d: usize) -> usize {
        (0..k_sim as usize).fold(0usize, |code, bit| {
            let normal = &self.hyperplanes[bit * d..(bit + 1) * d];
            let activation = normal
                .iter()
                .zip(&token[..d])
                .fold(0.0f32, |acc, (w, x)| acc + w * x);
            if activation > 0.0 {
                code | (1 << bit)
            } else {
                code
            }
        })
    }

    /// Projects `token` to `proj_dim` values, or returns a copy of `token`
    /// when this repetition has no projection matrix.
    #[inline]
    fn project(&self, token: &[f32], proj_dim: usize, d: usize) -> Vec<f32> {
        let matrix = match self.projection.as_ref() {
            Some(m) => m,
            None => return token.to_vec(),
        };
        (0..proj_dim)
            .map(|r| {
                matrix[r * d..(r + 1) * d]
                    .iter()
                    .zip(&token[..d])
                    .fold(0.0f32, |acc, (w, x)| acc + w * x)
            })
            .collect()
    }
}
/// Encoder that turns a set of token vectors into one fixed-length FDE vector.
///
/// The random matrices for every repetition are generated once in
/// [`FdeEncoder::new`] and reused by every encode call.
pub struct FdeEncoder {
    /// Validated parameters this encoder was built from.
    params: FdeParams,
    /// One set of matrices per repetition, in repetition order.
    reps: Vec<RepMatrices>,
}

impl FdeEncoder {
    /// Validates `params` and builds the per-repetition matrices.
    ///
    /// # Errors
    ///
    /// Returns whatever [`FdeParams::validate`] reports for `params`.
    pub fn new(params: &FdeParams) -> Result<Self, FdeError> {
        params.validate()?;
        let reps = (0..params.reps)
            .map(|rep| RepMatrices::build(params, rep))
            .collect();
        Ok(Self {
            params: params.clone(),
            reps,
        })
    }

    /// Parameters this encoder was built from.
    #[inline]
    pub fn params(&self) -> &FdeParams {
        &self.params
    }

    /// Length of every vector produced by this encoder.
    #[inline]
    pub fn fde_dim(&self) -> usize {
        self.params.fde_dim()
    }

    /// Rejects the input if any token's length differs from `input_dim`;
    /// the first offending token is the one reported.
    fn check_tokens(&self, tokens: &[Vec<f32>]) -> Result<(), FdeError> {
        let expected = self.params.input_dim as usize;
        match tokens.iter().find(|tok| tok.len() != expected) {
            Some(bad) => Err(FdeError::DimensionMismatch {
                got: bad.len(),
                expected,
            }),
            None => Ok(()),
        }
    }

    /// Adds each token's projection into the slot of its bucket inside
    /// `block` (one repetition's slice of the output) and, when `counts` is
    /// given, bumps that bucket's token count.
    fn scatter_add(
        &self,
        rep: &RepMatrices,
        tokens: &[Vec<f32>],
        block: &mut [f32],
        mut counts: Option<&mut [u32]>,
    ) {
        let pd = self.params.proj_dim();
        let d = self.params.input_dim as usize;
        for tok in tokens {
            let bk = rep.bucket_of(tok, self.params.k_sim, d);
            let proj = rep.project(tok, pd, d);
            for (acc, p) in block[bk * pd..(bk + 1) * pd].iter_mut().zip(&proj) {
                *acc += *p;
            }
            if let Some(c) = counts.as_mut() {
                c[bk] += 1;
            }
        }
    }

    /// Encodes a document.
    ///
    /// Within each repetition every non-empty bucket holds the mean of the
    /// projected tokens that fell into it, and every empty bucket is filled
    /// with a copy of the nearest non-empty bucket (by Hamming distance of
    /// the bucket indices, lowest index winning ties). An empty document
    /// encodes to all zeros.
    ///
    /// # Errors
    ///
    /// Returns [`FdeError::DimensionMismatch`] if any token's length differs
    /// from `input_dim`.
    pub fn encode_doc(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
        self.check_tokens(tokens)?;
        let mut out = vec![0.0f32; self.params.fde_dim()];
        if tokens.is_empty() {
            // Nothing to average and no bucket to copy from.
            return Ok(out);
        }
        let pd = self.params.proj_dim();
        let b = self.params.buckets();
        let mut counts = vec![0u32; b];
        // `out` is laid out as `reps` consecutive blocks of `b * pd` values.
        for (rep, block) in self.reps.iter().zip(out.chunks_exact_mut(b * pd)) {
            counts.fill(0);
            self.scatter_add(rep, tokens, block, Some(&mut counts[..]));

            // Turn per-bucket sums into means.
            for (slot, &n) in block.chunks_exact_mut(pd).zip(&counts) {
                if n > 0 {
                    let inv = 1.0f32 / n as f32;
                    for v in slot.iter_mut() {
                        *v *= inv;
                    }
                }
            }

            // Fill empty buckets. `counts` is not updated here, so sources are
            // always originally non-empty buckets and are never overwritten.
            for bk in 0..b {
                if counts[bk] > 0 {
                    continue;
                }
                if let Some(src) = nearest_nonempty(bk, &counts) {
                    block.copy_within(src * pd..(src + 1) * pd, bk * pd);
                }
            }
        }
        Ok(out)
    }

    /// Encodes a query.
    ///
    /// Within each repetition every bucket holds the sum of the projected
    /// tokens that fell into it; empty buckets stay zero.
    ///
    /// # Errors
    ///
    /// Returns [`FdeError::DimensionMismatch`] if any token's length differs
    /// from `input_dim`.
    pub fn encode_query(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
        self.check_tokens(tokens)?;
        let block_len = self.params.buckets() * self.params.proj_dim();
        let mut out = vec![0.0f32; self.params.fde_dim()];
        for (rep, block) in self.reps.iter().zip(out.chunks_exact_mut(block_len)) {
            self.scatter_add(rep, tokens, block, None);
        }
        Ok(out)
    }
}
/// Finds the non-empty bucket closest to `bucket`.
///
/// Distance is the Hamming distance between bucket indices. Among equally
/// close candidates the lowest index wins. Returns `None` when every count is
/// zero (or `counts` is empty).
#[inline]
fn nearest_nonempty(bucket: usize, counts: &[u32]) -> Option<usize> {
    counts
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count > 0)
        // Ordering on `(distance, index)` gives the smallest distance first
        // and breaks ties towards the smaller index.
        .map(|(candidate, _)| ((bucket ^ candidate).count_ones(), candidate))
        .min()
        .map(|(_, candidate)| candidate)
}
/// One-shot document encoding: builds an [`FdeEncoder`] from `params` and
/// encodes `tokens` with it.
///
/// The encoder (and its random matrices) is rebuilt on every call; keep an
/// [`FdeEncoder`] around when encoding many inputs with the same parameters.
///
/// # Errors
///
/// Propagates any error from [`FdeEncoder::new`] or [`FdeEncoder::encode_doc`].
pub fn encode_doc(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
    let encoder = FdeEncoder::new(params)?;
    encoder.encode_doc(tokens)
}
/// One-shot query encoding: builds an [`FdeEncoder`] from `params` and
/// encodes `tokens` with it.
///
/// The encoder (and its random matrices) is rebuilt on every call; keep an
/// [`FdeEncoder`] around when encoding many inputs with the same parameters.
///
/// # Errors
///
/// Propagates any error from [`FdeEncoder::new`] or [`FdeEncoder::encode_query`].
pub fn encode_query(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
    let encoder = FdeEncoder::new(params)?;
    encoder.encode_query(tokens)
}
// Unit tests for the FDE encoder: parameter arithmetic and validation,
// determinism, edge cases, and two statistical sanity checks. The statistical
// tests depend on the exact sequence of RNG draws made through `Gen`.
#[cfg(test)]
mod tests {
use super::*;
/// Exact reference score: for each query token, the maximum dot product over
/// all document tokens (0.0 for an empty document), summed over the query.
fn maxsim_dot(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
query
.iter()
.map(|q| {
if doc.is_empty() {
0.0
} else {
doc.iter()
.map(|d| dot(q, d))
.fold(f32::NEG_INFINITY, f32::max)
}
})
.sum()
}
/// Deterministic test-data generator built on the file's `SplitMix64`.
struct Gen(SplitMix64);
impl Gen {
/// Creates a generator from a fixed seed so test data is reproducible.
fn new(seed: u64) -> Self {
Self(SplitMix64::new(seed))
}
/// Draws a Gaussian vector of length `dim` and scales it to unit L2 norm
/// (the norm is clamped to at least 1e-12 before dividing).
fn unit_token(&mut self, dim: usize) -> Vec<f32> {
let mut v: Vec<f32> = (0..dim).map(|_| self.0.next_gaussian()).collect();
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
for x in &mut v {
*x /= norm;
}
v
}
/// Generates `n` unit tokens of length `dim`.
fn multivec(&mut self, n: usize, dim: usize) -> Vec<Vec<f32>> {
(0..n).map(|_| self.unit_token(dim)).collect()
}
/// Returns an integer in `lo..=hi`, reduced from one RNG draw by modulo.
fn count(&mut self, lo: usize, hi: usize) -> usize {
lo + (self.0.next_u64() as usize) % (hi - lo + 1)
}
}
/// Builds `FdeParams` with the default seed.
fn params(k_sim: u32, reps: u32, d_proj: u32, input_dim: u32) -> FdeParams {
FdeParams {
k_sim,
reps,
d_proj,
input_dim,
seed: DEFAULT_FDE_SEED,
}
}
/// Dot product over the common prefix of `a` and `b`.
fn dot(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(x, y)| x * y).sum()
}
/// Pearson correlation coefficient of two samples; the denominator is
/// clamped to at least 1e-12 to avoid dividing by zero.
fn pearson(xs: &[f32], ys: &[f32]) -> f32 {
let n = xs.len() as f32;
let mx = xs.iter().sum::<f32>() / n;
let my = ys.iter().sum::<f32>() / n;
let mut cov = 0.0;
let mut vx = 0.0;
let mut vy = 0.0;
for (x, y) in xs.iter().zip(ys) {
let dx = x - mx;
let dy = y - my;
cov += dx * dy;
vx += dx * dx;
vy += dy * dy;
}
cov / (vx.sqrt() * vy.sqrt()).max(1e-12)
}
// fde_dim is reps * 2^k_sim * proj_dim, where proj_dim falls back to
// input_dim when d_proj is 0.
#[test]
fn fde_dim_arithmetic() {
assert_eq!(params(4, 20, 16, 96).fde_dim(), 20 * 16 * 16);
assert_eq!(params(3, 2, 0, 8).fde_dim(), 2 * 8 * 8);
assert_eq!(params(4, 20, 16, 96).buckets(), 16);
}
// Rejected in order: k_sim == 0, k_sim above the maximum, reps == 0,
// input_dim == 0, and an fde_dim above MAX_FDE_DIM; a typical
// configuration is accepted.
#[test]
fn validate_rejects_bad_params() {
assert!(params(0, 1, 0, 8).validate().is_err()); assert!(params(MAX_K_SIM + 1, 1, 0, 8).validate().is_err());
assert!(params(4, 0, 0, 8).validate().is_err()); assert!(params(4, 1, 0, 0).validate().is_err()); assert!(params(16, 1000, 64, 96).validate().is_err());
assert!(params(4, 20, 16, 96).validate().is_ok());
}
// Encoding each document as a query must score highest against that same
// document's FDE among a corpus of 50 random documents.
#[test]
fn fde_self_retrieval_ranks_first() {
let dim = 32usize;
let p = params(4, 20, 16, dim as u32); let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(7);
let corpus: Vec<Vec<Vec<f32>>> = (0..50)
.map(|_| {
let n = g.count(4, 16);
g.multivec(n, dim)
})
.collect();
let dfde: Vec<Vec<f32>> = corpus.iter().map(|d| enc.encode_doc(d).unwrap()).collect();
for (j, d) in corpus.iter().enumerate() {
let fq = enc.encode_query(d).unwrap();
let top = (0..corpus.len())
.max_by(|&a, &b| dot(&fq, &dfde[a]).total_cmp(&dot(&fq, &dfde[b])))
.unwrap();
assert_eq!(top, j, "doc {j} did not self-retrieve as FDE top-1");
}
}
// Over 400 random query/document pairs, the FDE dot product must keep a
// Pearson correlation of at least 0.55 with the exact MaxSim score.
#[test]
fn fde_dot_positively_correlates_with_maxsim() {
let dim = 32usize;
let p = params(4, 24, 16, dim as u32);
let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(42);
let n_pairs = 400;
let mut fde_scores = Vec::with_capacity(n_pairs);
let mut exact_scores = Vec::with_capacity(n_pairs);
for _ in 0..n_pairs {
let (qn, dn) = (g.count(2, 6), g.count(4, 16));
let q = g.multivec(qn, dim);
let d = g.multivec(dn, dim);
fde_scores.push(dot(
&enc.encode_query(&q).unwrap(),
&enc.encode_doc(&d).unwrap(),
));
exact_scores.push(maxsim_dot(&q, &d));
}
let r = pearson(&fde_scores, &exact_scores);
assert!(r >= 0.55, "FDE/MaxSim correlation regressed: {r}");
}
// Two encoders built from the same params must produce identical output.
#[test]
fn deterministic_across_rebuild() {
let p = params(4, 8, 8, 16);
let e1 = FdeEncoder::new(&p).unwrap();
let e2 = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(7);
let d = g.multivec(10, 16);
assert_eq!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
let q = g.multivec(3, 16);
assert_eq!(e1.encode_query(&q).unwrap(), e2.encode_query(&q).unwrap());
}
// Changing only the seed must change the encoding of the same document.
#[test]
fn different_seed_changes_output() {
let mut p = params(4, 8, 8, 16);
let e1 = FdeEncoder::new(&p).unwrap();
p.seed = DEFAULT_FDE_SEED ^ 0xDEAD_BEEF;
let e2 = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(11);
let d = g.multivec(10, 16);
assert_ne!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
}
// An empty document has no bucket to copy from, so the result is all zeros
// of the full output length.
#[test]
fn empty_doc_is_all_zero() {
let p = params(4, 4, 8, 16);
let enc = FdeEncoder::new(&p).unwrap();
let fde = enc.encode_doc(&[]).unwrap();
assert_eq!(fde.len(), p.fde_dim());
assert!(fde.iter().all(|&x| x == 0.0));
}
// An empty query encodes to zeros and therefore scores 0 against any document.
#[test]
fn empty_query_scores_zero() {
let p = params(4, 4, 8, 16);
let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(3);
let fq = enc.encode_query(&[]).unwrap();
let fd = enc.encode_doc(&g.multivec(8, 16)).unwrap();
assert_eq!(dot(&fq, &fd), 0.0);
}
// A token of the wrong length is reported with both the actual and the
// expected dimension, for documents and queries alike.
#[test]
fn dim_mismatch_errors() {
let p = params(4, 4, 8, 16);
let enc = FdeEncoder::new(&p).unwrap();
let bad = vec![vec![1.0f32; 15]]; assert_eq!(
enc.encode_doc(&bad),
Err(FdeError::DimensionMismatch {
got: 15,
expected: 16
})
);
assert!(enc.encode_query(&bad).is_err());
}
// With a single token only one bucket is populated directly; empty-bucket
// filling must copy it into every other bucket.
#[test]
fn single_token_doc_fills_all_buckets() {
let p = params(3, 1, 0, 8); let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(99);
let tok = g.unit_token(8);
let fde = enc.encode_doc(&[tok]).unwrap();
let pd = p.proj_dim();
let first = &fde[0..pd];
for bk in 1..p.buckets() {
assert_eq!(&fde[bk * pd..(bk + 1) * pd], first, "bucket {bk} differs");
}
assert!(first.iter().any(|&x| x != 0.0));
}
// Query encoding does not fill empty buckets: one token touches exactly
// one bucket.
#[test]
fn query_leaves_empty_buckets_zero() {
let p = params(3, 1, 0, 8);
let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(123);
let tok = g.unit_token(8);
let fde = enc.encode_query(&[tok]).unwrap();
let pd = p.proj_dim();
let nonzero_buckets = (0..p.buckets())
.filter(|&bk| fde[bk * pd..(bk + 1) * pd].iter().any(|&x| x != 0.0))
.count();
assert_eq!(nonzero_buckets, 1);
}
// The free functions must return the same result as an explicitly built encoder.
#[test]
fn free_fns_match_encoder() {
let p = params(4, 4, 8, 16);
let enc = FdeEncoder::new(&p).unwrap();
let mut g = Gen::new(55);
let d = g.multivec(6, 16);
assert_eq!(encode_doc(&d, &p).unwrap(), enc.encode_doc(&d).unwrap());
let q = g.multivec(2, 16);
assert_eq!(encode_query(&q, &p).unwrap(), enc.encode_query(&q).unwrap());
}
}