use std::collections::HashMap;
use candle_core::{DType, Device, Tensor};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use trit_vsa::{PackedTritVec, Trit, vsa as trit_vsa_ops};
use crate::config::VSAConfig;
use crate::error::{OptimError, Result};
/// Per-gradient bookkeeping produced by `VSAGradientCompressor::compress` and
/// read back by `VSAGradientCompressor::decompress`.
#[derive(Debug, Clone)]
pub struct GradientMetadata {
/// Index handed to the binding-key generator for this gradient.
pub key_index: usize,
/// Mean absolute value of the projected gradient. It is the ternary
/// quantization threshold during compression and the multiplier applied to
/// the unbound trits during decompression.
pub scale: f32,
/// Dimensions of the original gradient tensor, used to reshape the
/// reconstruction.
pub shape: Vec<usize>,
}
/// Compresses a set of named gradient tensors into one packed ternary
/// hypervector and reconstructs approximations of them from it.
pub struct VSAGradientCompressor {
/// Settings read by this type: `dimension`, `compression_ratio` and `seed`.
config: VSAConfig,
/// Parameter count supplied at construction; used for sizing and statistics.
param_count: usize,
/// Length of every hypervector and binding key produced by this compressor.
hypervector_dim: usize,
/// Binding keys generated so far, keyed by key index.
key_cache: HashMap<usize, PackedTritVec>,
/// Projection matrices generated so far, keyed by flattened gradient size.
projection_cache: HashMap<usize, Tensor>,
}
impl VSAGradientCompressor {
/// Builds a compressor for a model with `param_count` parameters.
///
/// The hypervector length is the largest of `config.dimension`,
/// `param_count * config.compression_ratio` (truncated to an integer) and
/// `256`. Both caches start empty.
#[allow(clippy::cast_possible_truncation)]
#[allow(clippy::cast_sign_loss)]
#[must_use]
pub fn new(param_count: usize, config: VSAConfig) -> Self {
    // Size requested through the compression ratio, never below 256.
    let ratio_scaled = param_count as f32 * config.compression_ratio;
    let ratio_dim = ratio_scaled.max(256.0) as usize;
    let hypervector_dim = usize::max(config.dimension, ratio_dim);

    Self {
        config,
        param_count,
        hypervector_dim,
        key_cache: HashMap::new(),
        projection_cache: HashMap::new(),
    }
}
#[must_use]
pub const fn compressed_dim(&self) -> usize {
self.hypervector_dim
}
/// Returns the ternary binding key for `index`, generating and caching it on
/// first use.
///
/// The key is drawn from a ChaCha8 generator seeded with
/// `config.seed + index * 12345` (wrapping add), so the same seed and index
/// always yield the same key. Each position draws one uniform `f32`:
/// below `0.33` gives `N`, below `0.66` gives `Z`, otherwise `P`.
fn get_binding_key(&mut self, index: usize) -> PackedTritVec {
    use rand::Rng;

    match self.key_cache.get(&index).cloned() {
        Some(cached) => cached,
        None => {
            let offset = index as u64 * 12345;
            let mut rng = ChaCha8Rng::seed_from_u64(self.config.seed.wrapping_add(offset));

            let mut fresh = PackedTritVec::new(self.hypervector_dim);
            for position in 0..self.hypervector_dim {
                let draw: f32 = rng.gen();
                let trit = match draw {
                    d if d < 0.33 => Trit::N,
                    d if d < 0.66 => Trit::Z,
                    _ => Trit::P,
                };
                fresh.set(position, trit);
            }

            self.key_cache.insert(index, fresh.clone());
            fresh
        }
    }
}
/// Returns the sparse random projection matrix of shape
/// `(grad_size, hypervector_dim)` located on `device`.
///
/// The matrix is generated once per `grad_size` from a ChaCha8 generator
/// seeded with `config.seed + grad_size * 54321` (wrapping arithmetic) and then
/// cached. Each entry draws one uniform `f32`: below `0.16` gives
/// `+sqrt(3) / sqrt(hypervector_dim)`, below `0.32` gives the negated value,
/// anything else gives `0.0`.
///
/// The cache is keyed by `grad_size` only. When a cached matrix lives on a
/// different device than the one requested, a copy on `device` is returned so
/// callers never multiply tensors from two devices. The cached entry itself
/// stays on the device it was first created on.
///
/// # Errors
///
/// Propagates any tensor creation or device transfer error.
fn get_projection(&mut self, grad_size: usize, device: &Device) -> Result<Tensor> {
    use rand::Rng;

    if let Some(cached) = self.projection_cache.get(&grad_size) {
        if cached.device().same_device(device) {
            return Ok(cached.clone());
        }
        // Cached on another device (e.g. compressed on GPU, decompressed on
        // CPU): hand back a copy where the caller asked for it.
        return Ok(cached.to_device(device)?);
    }

    // Wrapping arithmetic keeps debug builds from panicking on overflow and
    // matches what release builds already compute.
    let seed = self
        .config
        .seed
        .wrapping_add((grad_size as u64).wrapping_mul(54321));
    let mut rng = ChaCha8Rng::seed_from_u64(seed);

    let scale = 1.0 / (self.hypervector_dim as f32).sqrt();
    let magnitude = scale * 3.0_f32.sqrt();

    let data: Vec<f32> = (0..grad_size * self.hypervector_dim)
        .map(|_| {
            let draw: f32 = rng.gen();
            if draw < 0.16 {
                magnitude
            } else if draw < 0.32 {
                -magnitude
            } else {
                0.0
            }
        })
        .collect();

    let proj = Tensor::from_vec(data, (grad_size, self.hypervector_dim), device)?;
    self.projection_cache.insert(grad_size, proj.clone());
    Ok(proj)
}
/// Projects `gradient` into hypervector space and quantizes it to trits.
///
/// The gradient is flattened, converted to `f32` and multiplied by the
/// projection matrix for its element count. The returned scale is the mean
/// absolute value of the projected vector (or `0.0` when it is empty).
/// Entries above `scale` become `P`, entries below `-scale` become `N`, the
/// rest `Z`. When the scale is not positive the freshly created packed vector
/// is returned untouched.
///
/// # Errors
///
/// Propagates any tensor operation error.
fn project_to_hypervector(
    &mut self,
    gradient: &Tensor,
) -> Result<(PackedTritVec, f32)> {
    let source_device = gradient.device();
    let flat = gradient.flatten_all()?.to_dtype(DType::F32)?;
    let proj = self.get_projection(flat.elem_count(), source_device)?;

    let row = flat.unsqueeze(0)?;
    let projected = row.matmul(&proj)?.squeeze(0)?;
    let values: Vec<f32> = projected.to_vec1()?;

    let scale = match values.len() {
        0 => 0.0,
        count => values.iter().copied().map(f32::abs).sum::<f32>() / count as f32,
    };

    let mut packed = PackedTritVec::new(self.hypervector_dim);
    if scale > 0.0 {
        for (slot, value) in values.iter().copied().enumerate() {
            let trit = match value {
                v if v > scale => Trit::P,
                v if v < -scale => Trit::N,
                _ => Trit::Z,
            };
            packed.set(slot, trit);
        }
    }

    Ok((packed, scale))
}
/// Compresses all `gradients` into a single bundled ternary hypervector.
///
/// Each gradient is projected and quantized, bound with its own key, and the
/// bound vectors are bundled together. The returned metadata maps every
/// gradient name to the key index, scale and shape needed by `decompress`.
///
/// Gradients are processed in ascending name order, so the key index given to
/// each name depends only on the set of names. `HashMap` iteration order is
/// arbitrary and must not influence the result.
///
/// # Errors
///
/// Returns `OptimError::EmptyInput` when `gradients` is empty, and propagates
/// any tensor operation error raised while projecting.
pub fn compress(
    &mut self,
    gradients: &HashMap<String, Tensor>,
) -> Result<(PackedTritVec, HashMap<String, GradientMetadata>)> {
    if gradients.is_empty() {
        return Err(OptimError::EmptyInput("No gradients to compress".to_string()));
    }

    // Fix the processing order; names are unique so an unstable sort is fine.
    let mut ordered: Vec<(&String, &Tensor)> = gradients.iter().collect();
    ordered.sort_unstable_by(|left, right| left.0.cmp(right.0));

    let mut metadata = HashMap::with_capacity(ordered.len());
    let mut bound_vectors: Vec<PackedTritVec> = Vec::with_capacity(ordered.len());

    for (index, (name, grad)) in ordered.into_iter().enumerate() {
        let (projected, scale) = self.project_to_hypervector(grad)?;
        let key = self.get_binding_key(index);
        bound_vectors.push(trit_vsa_ops::bind(&projected, &key));
        metadata.insert(
            name.clone(),
            GradientMetadata {
                key_index: index,
                scale,
                shape: grad.dims().to_vec(),
            },
        );
    }

    let refs: Vec<&PackedTritVec> = bound_vectors.iter().collect();
    let bundled = trit_vsa_ops::bundle_many(&refs);
    Ok((bundled, metadata))
}
/// Reconstructs an approximation of every gradient described by `metadata`
/// from the `bundled` hypervector.
///
/// For each entry the bundle is unbound with that entry's key, the trits are
/// turned into `f32` values multiplied by the stored scale, and the result is
/// multiplied by the transposed projection matrix and reshaped to the stored
/// shape. The intermediate tensors are created on the CPU device.
///
/// # Errors
///
/// Propagates any tensor operation error.
pub fn decompress(
    &mut self,
    bundled: &PackedTritVec,
    metadata: &HashMap<String, GradientMetadata>,
) -> Result<HashMap<String, Tensor>> {
    let cpu = Device::Cpu;
    let mut restored = HashMap::new();

    for (name, meta) in metadata {
        let key = self.get_binding_key(meta.key_index);
        let unbound = trit_vsa_ops::unbind(bundled, &key);

        let element_count = meta.shape.iter().product::<usize>();
        let proj = self.get_projection(element_count, &cpu)?;

        // Trits back to floats, rescaled by the stored magnitude.
        let mut scaled: Vec<f32> = Vec::with_capacity(self.hypervector_dim);
        for position in 0..self.hypervector_dim {
            scaled.push(unbound.get(position).value() as f32 * meta.scale);
        }
        let hyper = Tensor::from_vec(scaled, self.hypervector_dim, &cpu)?;

        let row = hyper.unsqueeze(0)?;
        let transposed = proj.t()?;
        let flat = row.matmul(&transposed)?.squeeze(0)?;

        restored.insert(name.clone(), flat.reshape(meta.shape.as_slice())?);
    }

    Ok(restored)
}
/// Summarizes how much this compressor shrinks the gradient data.
///
/// `compression_ratio` is `hypervector_dim / param_count`. `memory_saving` is
/// `1 - (hypervector_dim * 2 / 32) / param_count`. Both are computed in `f32`.
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn get_compression_stats(&self) -> CompressionStats {
    let dim = self.hypervector_dim as f32;
    let params = self.param_count as f32;

    CompressionStats {
        original_params: self.param_count,
        compressed_dim: self.hypervector_dim,
        compression_ratio: dim / params,
        memory_saving: 1.0 - (dim * 2.0 / 32.0) / params,
    }
}
pub fn clear_cache(&mut self) {
self.key_cache.clear();
self.projection_cache.clear();
}
}
/// Size figures reported by `VSAGradientCompressor::get_compression_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
/// Parameter count the compressor was created with.
pub original_params: usize,
/// Length of the compressed hypervector.
pub compressed_dim: usize,
/// `compressed_dim / original_params`.
pub compression_ratio: f32,
/// `1 - (compressed_dim * 2 / 32) / original_params`; shown as a percentage
/// by the `Display` implementation.
pub memory_saving: f32,
}
/// One-line human-readable summary: original size, compressed size and the
/// memory saving as a percentage with one decimal place.
impl std::fmt::Display for CompressionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            original_params,
            compressed_dim,
            memory_saving,
            ..
        } = self;
        let saved_percent = *memory_saving * 100.0;
        f.write_fmt(format_args!(
            "Compression: {} → {} ({:.1}% saved)",
            original_params, compressed_dim, saved_percent
        ))
    }
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds three random gradients (two weight matrices and one bias vector)
/// on `device`, keyed by layer-style names.
fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
    let layer1_weight = Tensor::randn(0.0f32, 1.0, (64, 128), device).unwrap();
    let layer1_bias = Tensor::randn(0.0f32, 1.0, 64, device).unwrap();
    let layer2_weight = Tensor::randn(0.0f32, 1.0, (32, 64), device).unwrap();

    HashMap::from([
        ("layer1.weight".to_string(), layer1_weight),
        ("layer1.bias".to_string(), layer1_bias),
        ("layer2.weight".to_string(), layer2_weight),
    ])
}
/// A default-configured compressor never goes below 256 dimensions.
#[test]
fn test_compressor_creation() {
    let config = VSAConfig::default();
    let compressor = VSAGradientCompressor::new(1_000_000, config);
    let dim = compressor.compressed_dim();
    assert!(dim >= 256);
}
/// Compressing then decompressing yields one tensor per input with the
/// original shape.
#[test]
fn test_compress_decompress_roundtrip() {
    let device = Device::Cpu;
    let gradients = create_mock_gradients(&device);
    let param_count: usize = gradients.values().map(Tensor::elem_count).sum();

    let config = VSAConfig::default().with_compression_ratio(0.5);
    let mut compressor = VSAGradientCompressor::new(param_count, config);

    let (bundled, metadata) = compressor.compress(&gradients).unwrap();
    assert_eq!(bundled.len(), compressor.compressed_dim());
    assert_eq!(metadata.len(), 3);

    let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();
    assert_eq!(reconstructed.len(), 3);

    for (name, original) in &gradients {
        let restored = reconstructed.get(name).unwrap();
        assert_eq!(original.dims(), restored.dims());
    }
}
/// Statistics echo the parameter count and report a saving above 90%.
#[test]
fn test_compression_stats() {
    let stats =
        VSAGradientCompressor::new(1_000_000, VSAConfig::default()).get_compression_stats();

    assert_eq!(stats.original_params, 1_000_000);
    assert!(stats.memory_saving > 0.9);
}
/// For gradients with at least 1024 elements, the reconstruction keeps a
/// cosine similarity above 0.1 with the original.
#[test]
fn test_direction_preservation() {
    // Euclidean norm, accumulated in the same order as the data.
    let l2_norm = |values: &[f32]| values.iter().map(|x| x * x).sum::<f32>().sqrt();

    let device = Device::Cpu;
    let gradients = create_mock_gradients(&device);
    let param_count: usize = gradients.values().map(Tensor::elem_count).sum();

    let config = VSAConfig::default()
        .with_dimension(8192)
        .with_compression_ratio(0.5);
    let mut compressor = VSAGradientCompressor::new(param_count, config);

    let (bundled, metadata) = compressor.compress(&gradients).unwrap();
    let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();

    for (name, orig) in &gradients {
        let recon = reconstructed.get(name).unwrap();
        let orig_data: Vec<f32> = orig.flatten_all().unwrap().to_vec1().unwrap();
        let recon_data: Vec<f32> = recon.flatten_all().unwrap().to_vec1().unwrap();

        let dot: f32 = orig_data
            .iter()
            .zip(recon_data.iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm_orig = l2_norm(&orig_data);
        let norm_recon = l2_norm(&recon_data);

        // Skip degenerate vectors where the cosine is meaningless.
        if norm_orig < 1e-6 || norm_recon < 1e-6 {
            continue;
        }

        let cosine = dot / (norm_orig * norm_recon + 1e-8);
        if orig.elem_count() < 1024 {
            continue;
        }
        assert!(
            cosine > 0.1,
            "Gradient direction not preserved for {name}: cosine = {cosine}"
        );
    }
}
/// Keys for different indices differ in enough positions, and unbinding with
/// the same key recovers the bound vector exactly.
#[test]
fn test_bind_unbind_property() {
    let config = VSAConfig::default().with_dimension(1024);
    let mut compressor = VSAGradientCompressor::new(1000, config);

    let key0 = compressor.get_binding_key(0);
    let key1 = compressor.get_binding_key(1);

    let same_count = (0..key0.len())
        .filter(|&i| key0.get(i) == key1.get(i))
        .count();
    assert!(same_count < key0.len() * 2 / 3);

    let probe = key0.clone();
    let bound = trit_vsa_ops::bind(&probe, &key1);
    let recovered = trit_vsa_ops::unbind(&bound, &key1);
    for position in 0..probe.len() {
        assert_eq!(probe.get(position), recovered.get(position));
    }
}
/// Compressing an empty gradient map is rejected with an error.
#[test]
fn test_empty_gradients() {
    let no_gradients: HashMap<String, Tensor> = HashMap::new();
    let mut compressor = VSAGradientCompressor::new(1000, VSAConfig::default());
    assert!(compressor.compress(&no_gradients).is_err());
}
}