mod compressor;
pub use compressor::{CompressionStats, GradientMetadata, VSAGradientCompressor};
use candle_core::Tensor;
use trit_vsa::{PackedTritVec, Trit};
use crate::error::{OptimError, Result};
pub fn hyperdimensional_bind(a: &Tensor, b: &Tensor) -> Result<Tensor> {
if a.dims() != b.dims() {
return Err(OptimError::DimensionMismatch {
expected: a.dims().iter().product(),
actual: b.dims().iter().product(),
});
}
Ok(a.mul(b)?)
}
pub fn hyperdimensional_bind_ternary(a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec> {
if a.len() != b.len() {
return Err(OptimError::DimensionMismatch {
expected: a.len(),
actual: b.len(),
});
}
Ok(trit_vsa::vsa::bind(a, b))
}
pub fn hyperdimensional_bundle(vectors: &[Tensor], weights: Option<&[f32]>) -> Result<Tensor> {
if vectors.is_empty() {
return Err(OptimError::EmptyInput("Cannot bundle empty list".to_string()));
}
let dim = vectors[0].dims();
for v in vectors.iter().skip(1) {
if v.dims() != dim {
return Err(OptimError::ShapeMismatch {
expected: dim.to_vec(),
actual: v.dims().to_vec(),
});
}
}
let n = vectors.len();
let weights: Vec<f32> = weights
.map(|w| w.to_vec())
.unwrap_or_else(|| vec![1.0; n]);
if weights.len() != n {
return Err(OptimError::DimensionMismatch {
expected: n,
actual: weights.len(),
});
}
let mut result = vectors[0].zeros_like()?;
for (v, &w) in vectors.iter().zip(weights.iter()) {
let weighted = (v * w as f64)?;
result = result.add(&weighted)?;
}
Ok((result / n as f64)?)
}
pub fn hyperdimensional_bundle_ternary(vectors: &[&PackedTritVec]) -> Result<PackedTritVec> {
if vectors.is_empty() {
return Err(OptimError::EmptyInput("Cannot bundle empty list".to_string()));
}
Ok(trit_vsa::vsa::bundle_many(vectors))
}
#[allow(clippy::cast_possible_truncation)]
pub fn ternary_quantize(x: &Tensor, scale: Option<f32>) -> Result<(Tensor, f32)> {
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let flat = x_f32.flatten_all()?;
let data: Vec<f32> = flat.to_vec1()?;
let scale = scale.unwrap_or_else(|| {
if data.is_empty() {
0.0
} else {
data.iter().map(|v| v.abs()).sum::<f32>() / data.len() as f32
}
});
if scale == 0.0 {
return Ok((x.zeros_like()?, 0.0));
}
let quantized: Vec<f32> = data
.iter()
.map(|&v| {
if v > scale {
1.0
} else if v < -scale {
-1.0
} else {
0.0
}
})
.collect();
let result = Tensor::from_vec(quantized, x.shape(), x.device())?;
Ok((result, scale))
}
#[allow(clippy::cast_possible_truncation)]
pub fn ternary_quantize_packed(x: &Tensor, scale: Option<f32>) -> Result<(PackedTritVec, f32)> {
let x_f32 = x.to_dtype(candle_core::DType::F32)?;
let flat = x_f32.flatten_all()?;
let data: Vec<f32> = flat.to_vec1()?;
let scale = scale.unwrap_or_else(|| {
if data.is_empty() {
0.0
} else {
data.iter().map(|v| v.abs()).sum::<f32>() / data.len() as f32
}
});
let mut packed = PackedTritVec::new(data.len());
if scale == 0.0 {
return Ok((packed, 0.0));
}
for (i, &v) in data.iter().enumerate() {
let trit = if v > scale {
Trit::P
} else if v < -scale {
Trit::N
} else {
Trit::Z
};
packed.set(i, trit);
}
Ok((packed, scale))
}
#[allow(clippy::cast_precision_loss)]
pub fn ternary_dequantize_packed(
packed: &PackedTritVec,
scale: f32,
shape: &[usize],
device: &candle_core::Device,
) -> Result<Tensor> {
let data: Vec<f32> = (0..packed.len())
.map(|i| packed.get(i).value() as f32 * scale)
.collect();
Ok(Tensor::from_vec(data, shape, device)?)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn test_hyperdimensional_bind_shape() {
let device = Device::Cpu;
let a = Tensor::randn(0.0f32, 1.0, 100, &device).unwrap();
let b = Tensor::randn(0.0f32, 1.0, 100, &device).unwrap();
let result = hyperdimensional_bind(&a, &b).unwrap();
assert_eq!(result.dims(), a.dims());
}
#[test]
fn test_hyperdimensional_bind_dimension_mismatch() {
let device = Device::Cpu;
let a = Tensor::randn(0.0f32, 1.0, 100, &device).unwrap();
let b = Tensor::randn(0.0f32, 1.0, 50, &device).unwrap();
assert!(hyperdimensional_bind(&a, &b).is_err());
}
#[test]
fn test_hyperdimensional_bundle_shape() {
let device = Device::Cpu;
let vectors: Vec<Tensor> = (0..5)
.map(|_| Tensor::randn(0.0f32, 1.0, 100, &device).unwrap())
.collect();
let result = hyperdimensional_bundle(&vectors, None).unwrap();
assert_eq!(result.dims(), vectors[0].dims());
}
#[test]
fn test_hyperdimensional_bundle_empty() {
let result = hyperdimensional_bundle(&[], None);
assert!(result.is_err());
}
#[test]
fn test_ternary_quantize_values() {
let device = Device::Cpu;
let x = Tensor::randn(0.0f32, 1.0, 1000, &device).unwrap();
let (quantized, scale) = ternary_quantize(&x, None).unwrap();
assert!(scale > 0.0);
let values: Vec<f32> = quantized.flatten_all().unwrap().to_vec1().unwrap();
for v in values {
assert!(v == -1.0 || v == 0.0 || v == 1.0, "Unexpected value: {v}");
}
}
#[test]
fn test_ternary_quantize_zeros() {
let device = Device::Cpu;
let x = Tensor::zeros(100, candle_core::DType::F32, &device).unwrap();
let (quantized, scale) = ternary_quantize(&x, None).unwrap();
assert_eq!(scale, 0.0);
let values: Vec<f32> = quantized.flatten_all().unwrap().to_vec1().unwrap();
assert!(values.iter().all(|&v| v == 0.0));
}
#[test]
fn test_ternary_quantize_packed_roundtrip() {
let device = Device::Cpu;
let x = Tensor::randn(0.0f32, 1.0, (10, 10), &device).unwrap();
let (packed, scale) = ternary_quantize_packed(&x, None).unwrap();
assert_eq!(packed.len(), 100);
let restored = ternary_dequantize_packed(&packed, scale, &[10, 10], &device).unwrap();
assert_eq!(restored.dims(), x.dims());
}
#[test]
fn test_bind_ternary() {
let mut a = PackedTritVec::new(100);
let mut b = PackedTritVec::new(100);
for i in 0..100 {
a.set(i, if i % 3 == 0 { Trit::P } else { Trit::Z });
b.set(i, if i % 2 == 0 { Trit::N } else { Trit::P });
}
let bound = hyperdimensional_bind_ternary(&a, &b).unwrap();
assert_eq!(bound.len(), 100);
}
}