#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
use crate::hyperdim::HVec10240;
/// Parameters that control how a [`Projection`] matrix is generated.
#[derive(Debug, Clone)]
pub struct ProjectionConfig {
    /// Seed fed to the random number generator, so that the same
    /// configuration always yields the same matrix.
    pub seed: u64,
    /// Number of matrix columns, i.e. the length of the input vectors.
    pub native_dim: usize,
    /// Number of matrix rows that are generated.
    pub target_dim: usize,
    /// Threshold compared against one random draw per matrix cell; a draw
    /// below this value leaves the cell at zero.
    pub sparsity: f32,
}

impl Default for ProjectionConfig {
    /// Returns the stock configuration: seed 42, a 384-column by 10240-row
    /// matrix, and a sparsity threshold of two thirds.
    fn default() -> Self {
        const SEED: u64 = 42;
        const NATIVE_DIM: usize = 384;
        const TARGET_DIM: usize = 10240;
        const SPARSITY: f32 = 2.0 / 3.0;

        Self {
            seed: SEED,
            native_dim: NATIVE_DIM,
            target_dim: TARGET_DIM,
            sparsity: SPARSITY,
        }
    }
}
/// Sparse random sign projection from a dense `f32` vector onto the bits of
/// an [`HVec10240`].
///
/// Only the non-zero cells of the projection matrix are stored, each as a
/// `(row, col, sign)` triple whose sign is `+1` or `-1`.
#[derive(Debug, Clone)]
pub struct Projection {
    /// Non-zero matrix cells as `(row, col, sign)`, in the row-major order in
    /// which they were generated.
    entries: Vec<(usize, usize, i8)>,
    /// Length every input slice handed to [`Projection::project`] must have.
    native_dim: usize,
    /// Number of matrix rows requested when the projection was built.
    target_dim: usize,
}

impl Projection {
    /// Number of per-row sums (and therefore output bits) computed by
    /// [`Projection::project`].
    const HYPER_DIM: usize = 10240;
    /// Width in bits of one `u128` storage word of the output vector.
    const WORD_BITS: usize = 128;

    /// Creates a projection with no matrix cells and zero dimensions.
    #[must_use]
    pub const fn empty() -> Self {
        Self {
            entries: Vec::new(),
            native_dim: 0,
            target_dim: 0,
        }
    }

    /// Builds the projection matrix described by `config`.
    ///
    /// Cells are visited row by row. For each cell one `f32` is drawn; if it
    /// is below `config.sparsity` the cell stays zero, otherwise a second
    /// draw picks `+1` or `-1` with equal probability. The generator is
    /// seeded from `config.seed`, so equal configurations produce equal
    /// matrices.
    #[must_use]
    pub fn new(config: &ProjectionConfig) -> Self {
        use rand::RngExt;
        use rand::SeedableRng;
        use rand::rngs::StdRng;

        let mut rng = StdRng::seed_from_u64(config.seed);
        let mut entries = Vec::new();

        // The sequence of RNG calls defines the matrix for a given seed:
        // keep the "skip draw first, sign draw only for kept cells" order.
        for row in 0..config.target_dim {
            for col in 0..config.native_dim {
                let skip_draw: f32 = rng.random();
                if skip_draw < config.sparsity {
                    continue;
                }
                let sign: i8 = if rng.random_bool(0.5) { 1 } else { -1 };
                entries.push((row, col, sign));
            }
        }

        Self {
            entries,
            native_dim: config.native_dim,
            target_dim: config.target_dim,
        }
    }

    /// Projects `vec` through the sign matrix and quantises each row sum to
    /// one bit.
    ///
    /// Bit `i` of the result is set when the signed sum for row `i` is
    /// greater than or equal to zero. Rows without any stored cell keep a sum
    /// of `0.0` and therefore also produce a set bit.
    ///
    /// # Panics
    ///
    /// Panics if `vec.len()` differs from the projection's native dimension,
    /// or if a stored cell has a row index of 10240 or more (which can only
    /// happen when the projection was built with `target_dim > 10240`).
    pub fn project(&self, vec: &[f32]) -> HVec10240 {
        assert!(vec.len() == self.native_dim, "input dimension mismatch");

        // Accumulate in stored order so the floating-point sums are
        // reproducible from run to run.
        let mut sums = vec![0.0_f32; Self::HYPER_DIM];
        for &(row, col, sign) in &self.entries {
            sums[row] += f32::from(sign) * vec[col];
        }

        let mut hv = HVec10240::zero();
        for (bit_index, _) in sums.iter().enumerate().filter(|&(_, &sum)| sum >= 0.0) {
            hv.data[bit_index / Self::WORD_BITS] |= 1u128 << (bit_index % Self::WORD_BITS);
        }
        hv
    }

    /// Returns the number of non-zero matrix cells.
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.entries.len()
    }

    /// Returns the fraction of matrix cells that are non-zero, i.e.
    /// `nnz / (target_dim * native_dim)`.
    ///
    /// Despite the name this is the matrix density: a projection built with
    /// a sparsity threshold of `2/3` reports roughly `1/3` here. A matrix
    /// with no cells at all (for example [`Projection::empty`]) reports
    /// `0.0` rather than dividing by zero.
    #[must_use]
    pub fn sparsity_ratio(&self) -> f32 {
        let total_cells = self.target_dim * self.native_dim;
        if total_cells == 0 {
            return 0.0;
        }
        self.entries.len() as f32 / total_cells as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with the 10240-row target and 2/3 sparsity
    /// threshold shared by the tests below.
    fn config_with(seed: u64, native_dim: usize) -> ProjectionConfig {
        ProjectionConfig {
            seed,
            native_dim,
            target_dim: 10240,
            sparsity: 2.0 / 3.0,
        }
    }

    /// The number of stored cells must be within 10% of one third of the
    /// full matrix.
    #[test]
    fn projection_sparsity_is_correct() {
        let proj = Projection::new(&config_with(42, 384));
        let expected_nnz = 10240 * 384 / 3;
        let lower = expected_nnz * 9 / 10;
        let upper = expected_nnz * 11 / 10;
        let actual = proj.nnz();
        assert!(
            lower < actual && actual < upper,
            "nnz {actual} not close to expected {expected_nnz}"
        );
    }

    /// Two projections built from the same configuration hold identical cells.
    #[test]
    fn projection_is_deterministic() {
        let config = ProjectionConfig::default();
        let first = Projection::new(&config);
        let second = Projection::new(&config);
        assert_eq!(first.entries, second.entries);
    }

    /// Perturbing the first 38 of 384 components keeps the projected vectors
    /// similar.
    #[test]
    fn projection_preserves_similarity() {
        let proj = Projection::new(&config_with(42, 384));
        let v1 = vec![0.1_f32; 384];
        let mut v2 = v1.clone();
        v2[..38].fill(0.2);
        let h1 = proj.project(&v1);
        let h2 = proj.project(&v2);
        let sim = h1.cosine_similarity(&h2);
        assert!(sim > 0.5, "similarity {sim} too low after projection");
    }

    /// An empty projection has no cells and a native dimension of zero.
    #[test]
    fn projection_empty_works() {
        let blank = Projection::empty();
        assert_eq!(blank.nnz(), 0);
        assert_eq!(blank.native_dim, 0);
    }

    /// Two 1536-component inputs differing by a constant 0.1 offset must
    /// project to highly similar vectors.
    #[test]
    fn projection_accuracy_preservation() {
        let proj = Projection::new(&config_with(123, 1536));
        let v1: Vec<f32> = (0..1536_usize).map(|i| (i as f32).sin()).collect();
        let v2: Vec<f32> = v1.iter().map(|x| x + 0.1).collect();
        let h1 = proj.project(&v1);
        let h2 = proj.project(&v2);
        let sim = h1.cosine_similarity(&h2);
        assert!(
            sim >= 0.9,
            "Projection cosine similarity {sim} should be high for similar pairs"
        );
    }
}