use crate::RetrieveError;
/// A vector quantized to ternary values {-1, 0, +1}, packed at 2 bits per
/// component (four components per byte).
#[derive(Clone, Debug)]
pub struct TernaryVector {
// Packed 2-bit codes, low bits first within a byte: 0b00 => 0, 0b01 => +1, 0b10 => -1.
data: Vec<u8>,
// Logical component count; the last byte of `data` may carry unused padding bits.
dimension: usize,
// Number of components quantized to +1.
positive_count: usize,
// Number of components quantized to -1.
negative_count: usize,
// Euclidean norm of the raw input vector, captured at quantization time.
original_norm: f32,
}
impl TernaryVector {
/// Number of logical components in the vector.
pub fn dimension(&self) -> usize {
self.dimension
}
/// Decodes component `idx` to -1, 0, or +1.
/// Out-of-range indices (and the reserved code 0b11) decode to 0.
pub fn get(&self, idx: usize) -> i8 {
    if idx >= self.dimension {
        return 0;
    }
    // Four 2-bit codes per byte, stored low bits first.
    let code = (self.data[idx / 4] >> ((idx % 4) * 2)) & 0b11;
    if code == 0b01 {
        1
    } else if code == 0b10 {
        -1
    } else {
        0
    }
}
/// Fraction of components that quantized to zero (0.0 = fully dense, 1.0 = all zero).
pub fn sparsity(&self) -> f32 {
    let nonzero_fraction =
        (self.positive_count + self.negative_count) as f32 / self.dimension as f32;
    1.0 - nonzero_fraction
}
/// Size in bytes of the packed representation.
pub fn memory_bytes(&self) -> usize {
self.data.len()
}
}
/// Configuration for `TernaryQuantizer`.
#[derive(Clone, Debug)]
pub struct TernaryConfig {
/// Values strictly above this quantize to +1 (superseded by fitted adaptive thresholds).
pub threshold_high: f32,
/// Values strictly below this quantize to -1 (superseded by fitted adaptive thresholds).
pub threshold_low: f32,
/// If true, L2-normalize each vector (after mean-centering, when fitted) before thresholding.
pub normalize: bool,
/// If set, `fit` derives per-dimension thresholds so roughly this fraction
/// of components quantize to zero.
pub target_sparsity: Option<f32>,
}
impl Default for TernaryConfig {
/// Symmetric ±0.3 fixed thresholds, normalization enabled, no sparsity target.
fn default() -> Self {
Self {
threshold_high: 0.3,
threshold_low: -0.3,
normalize: true,
target_sparsity: None,
}
}
}
/// Quantizes f32 vectors of a fixed dimension into `TernaryVector`s,
/// using either the fixed thresholds in `config` or statistics learned by `fit`.
pub struct TernaryQuantizer {
config: TernaryConfig,
dimension: usize,
// Per-dimension (low, high) thresholds; populated by `fit` when
// `config.target_sparsity` is set, and preferred over the fixed thresholds.
adaptive_thresholds: Option<Vec<(f32, f32)>>,
// Per-dimension mean; populated by `fit` and subtracted before thresholding.
mean: Option<Vec<f32>>,
}
impl TernaryQuantizer {
pub fn new(dimension: usize, config: TernaryConfig) -> Self {
Self {
config,
dimension,
adaptive_thresholds: None,
mean: None,
}
}
/// Creates a quantizer with the default `TernaryConfig`.
pub fn with_dimension(dimension: usize) -> Self {
Self::new(dimension, TernaryConfig::default())
}
/// Fits per-dimension statistics from a row-major training set
/// (`num_vectors` rows of `self.dimension` contiguous f32s).
///
/// Computes the per-dimension mean (subtracted in later `quantize` calls) and,
/// when `config.target_sparsity` is set, per-dimension (low, high) thresholds
/// taken from the empirical quantiles so that roughly `target_sparsity` of
/// components quantize to zero, split evenly between +1 and -1.
///
/// # Errors
/// Returns `RetrieveError::Other` when `vectors.len() != num_vectors * self.dimension`,
/// or when `num_vectors` is zero (previously this produced NaN means via a 0/0
/// division, and underflowed `num_vectors - 1` in the adaptive-threshold path).
pub fn fit(&mut self, vectors: &[f32], num_vectors: usize) -> Result<(), RetrieveError> {
    if vectors.len() != num_vectors * self.dimension {
        return Err(RetrieveError::Other("Vector count mismatch".to_string()));
    }
    // Guard: with zero vectors the mean would be 0/0 = NaN and the quantile
    // index `num_vectors - 1` below would underflow.
    if num_vectors == 0 {
        return Err(RetrieveError::Other(
            "Cannot fit on zero vectors".to_string(),
        ));
    }
    // Per-dimension mean over all training rows.
    let mut mean = vec![0.0f32; self.dimension];
    for i in 0..num_vectors {
        let row = &vectors[i * self.dimension..(i + 1) * self.dimension];
        for (acc, &v) in mean.iter_mut().zip(row) {
            *acc += v;
        }
    }
    for m in &mut mean {
        *m /= num_vectors as f32;
    }
    if let Some(target_sparsity) = self.config.target_sparsity {
        // Fraction of components that should land on each nonzero side.
        let nonzero_fraction = (1.0 - target_sparsity) / 2.0;
        let mut thresholds = Vec::with_capacity(self.dimension);
        for d in 0..self.dimension {
            // Mean-centered values of dimension `d` across the training set.
            let mut values: Vec<f32> = (0..num_vectors)
                .map(|i| vectors[i * self.dimension + d] - mean[d])
                .collect();
            values.sort_unstable_by(|a, b| a.total_cmp(b));
            // Quantile indices: the bottom `nonzero_fraction` of values will
            // fall below `low`, the top `nonzero_fraction` above `high`.
            let low_idx =
                ((nonzero_fraction * num_vectors as f32) as usize).min(num_vectors - 1);
            let high_idx =
                (((1.0 - nonzero_fraction) * num_vectors as f32) as usize).min(num_vectors - 1);
            thresholds.push((values[low_idx], values[high_idx]));
        }
        self.adaptive_thresholds = Some(thresholds);
    }
    self.mean = Some(mean);
    Ok(())
}
pub fn quantize(&self, vector: &[f32]) -> Result<TernaryVector, RetrieveError> {
if vector.len() != self.dimension {
return Err(RetrieveError::Other(format!(
"Expected {} dimensions, got {}",
self.dimension,
vector.len()
)));
}
let centered: Vec<f32> = if let Some(ref mean) = self.mean {
vector
.iter()
.zip(mean.iter())
.map(|(&v, &m)| v - m)
.collect()
} else {
vector.to_vec()
};
let processed: Vec<f32> = if self.config.normalize {
let norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-10 {
centered.iter().map(|&x| x / norm).collect()
} else {
centered
}
} else {
centered
};
let original_norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
let num_bytes = self.dimension.div_ceil(4);
let mut data = vec![0u8; num_bytes];
let mut positive_count = 0;
let mut negative_count = 0;
for (i, &v) in processed.iter().enumerate() {
let (thresh_low, thresh_high) = if let Some(ref thresholds) = self.adaptive_thresholds {
thresholds[i]
} else {
(self.config.threshold_low, self.config.threshold_high)
};
let bits: u8 = if v > thresh_high {
positive_count += 1;
0b01 } else if v < thresh_low {
negative_count += 1;
0b10 } else {
0b00 };
let byte_idx = i / 4;
let bit_offset = (i % 4) * 2;
data[byte_idx] |= bits << bit_offset;
}
Ok(TernaryVector {
data,
dimension: self.dimension,
positive_count,
negative_count,
original_norm,
})
}
/// Quantizes `num_vectors` row-major vectors, stopping at the first failure.
///
/// # Errors
/// Returns `RetrieveError::Other` when `vectors.len() != num_vectors * self.dimension`,
/// or propagates the first error from `quantize`.
pub fn quantize_batch(
    &self,
    vectors: &[f32],
    num_vectors: usize,
) -> Result<Vec<TernaryVector>, RetrieveError> {
    if vectors.len() != num_vectors * self.dimension {
        return Err(RetrieveError::Other("Vector count mismatch".to_string()));
    }
    (0..num_vectors)
        .map(|i| {
            let start = i * self.dimension;
            self.quantize(&vectors[start..start + self.dimension])
        })
        .collect()
}
}
/// Exact integer inner product of two ternary vectors.
/// Returns 0 when the dimensions differ. Padding bits in the final byte
/// decode to 0 and therefore never contribute.
pub fn ternary_inner_product(a: &TernaryVector, b: &TernaryVector) -> i32 {
    if a.dimension != b.dimension {
        return 0;
    }
    // Decode one 2-bit field into its ternary value.
    fn decode(bits: u8) -> i32 {
        match bits {
            0b01 => 1,
            0b10 => -1,
            _ => 0,
        }
    }
    let mut acc: i32 = 0;
    for (&byte_a, &byte_b) in a.data.iter().zip(&b.data) {
        for shift in [0u8, 2, 4, 6] {
            acc += decode((byte_a >> shift) & 0b11) * decode((byte_b >> shift) & 0b11);
        }
    }
    acc
}
/// Cosine similarity between two ternary vectors.
/// Each nonzero component contributes ±1, so a vector's norm is
/// sqrt(positive_count + negative_count). Returns 0.0 when either side
/// has no nonzero components.
pub fn ternary_cosine_similarity(a: &TernaryVector, b: &TernaryVector) -> f32 {
    let norm_a = ((a.positive_count + a.negative_count) as f32).sqrt();
    let norm_b = ((b.positive_count + b.negative_count) as f32).sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        return 0.0;
    }
    ternary_inner_product(a, b) as f32 / (norm_a * norm_b)
}
/// Inner product between a full-precision query and a ternary vector:
/// sums the query components that fall on +1 codes minus those on -1 codes.
/// Returns 0.0 when the dimensions differ.
pub fn asymmetric_inner_product(query: &[f32], quantized: &TernaryVector) -> f32 {
    if query.len() != quantized.dimension {
        return 0.0;
    }
    query
        .iter()
        .enumerate()
        .map(|(i, &q)| q * quantized.get(i) as f32)
        .sum()
}
/// Cosine distance (1 - cosine similarity) between a full-precision query
/// and a ternary vector. Returns the maximally-dissimilar value 1.0 when
/// either side has (near-)zero norm.
pub fn asymmetric_cosine_distance(query: &[f32], quantized: &TernaryVector) -> f32 {
    let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
    let ternary_norm = ((quantized.positive_count + quantized.negative_count) as f32).sqrt();
    if query_norm < 1e-10 || ternary_norm < 1e-10 {
        return 1.0;
    }
    let ip = asymmetric_inner_product(query, quantized);
    1.0 - ip / (query_norm * ternary_norm)
}
/// Hamming distance: the number of components where the two ternary vectors
/// decode to different values. Mismatched dimensions are treated as maximally
/// distant (the larger dimension).
pub fn ternary_hamming(a: &TernaryVector, b: &TernaryVector) -> usize {
    if a.dimension != b.dimension {
        return a.dimension.max(b.dimension);
    }
    (0..a.dimension).filter(|&i| a.get(i) != b.get(i)).count()
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a config with symmetric ±threshold and normalization disabled,
    // shared by the fixed-threshold tests below.
    fn no_norm_config(threshold: f32) -> TernaryConfig {
        TernaryConfig {
            threshold_high: threshold,
            threshold_low: -threshold,
            normalize: false,
            target_sparsity: None,
        }
    }

    #[test]
    fn test_basic_quantization() {
        let quantizer = TernaryQuantizer::with_dimension(8);
        let input = [0.5, -0.5, 0.1, -0.1, 0.8, -0.8, 0.0, 0.2];
        let quantized = quantizer.quantize(&input).unwrap();
        assert_eq!(quantized.dimension(), 8);
        // Eight 2-bit codes pack into at most two bytes.
        assert!(quantized.memory_bytes() <= 2);
    }

    #[test]
    fn test_ternary_values() {
        let quantizer = TernaryQuantizer::new(4, no_norm_config(0.3));
        let quantized = quantizer.quantize(&[0.5, -0.5, 0.1, -0.1]).unwrap();
        assert_eq!(quantized.get(0), 1);
        assert_eq!(quantized.get(1), -1);
        assert_eq!(quantized.get(2), 0);
        assert_eq!(quantized.get(3), 0);
    }

    #[test]
    fn test_inner_product() {
        let quantizer = TernaryQuantizer::new(4, no_norm_config(0.3));
        let q_left = quantizer.quantize(&[0.5, -0.5, 0.1, 0.0]).unwrap();
        let q_right = quantizer.quantize(&[0.5, 0.5, -0.5, 0.0]).unwrap();
        // (+1)(+1) + (-1)(+1) + 0*(-1) + 0*0 = 0.
        assert_eq!(ternary_inner_product(&q_left, &q_right), 0);
    }

    #[test]
    fn test_asymmetric_distance() {
        let quantizer = TernaryQuantizer::new(4, no_norm_config(0.3));
        let quantized = quantizer.quantize(&[0.5, -0.5, 0.1, 0.0]).unwrap();
        // The +1 and -1 components cancel against an all-ones query.
        let ip = asymmetric_inner_product(&[1.0, 1.0, 1.0, 1.0], &quantized);
        assert_eq!(ip, 0.0);
        // Aligned query: 1*1 + (-1)*(-1) = 2.
        let ip2 = asymmetric_inner_product(&[1.0, -1.0, 0.0, 0.0], &quantized);
        assert_eq!(ip2, 2.0);
    }

    #[test]
    fn test_sparsity() {
        let quantizer = TernaryQuantizer::new(4, no_norm_config(0.5));
        let quantized = quantizer.quantize(&[0.6, -0.6, 0.1, 0.2]).unwrap();
        // Two of the four components land in the zero band.
        assert!((quantized.sparsity() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_fit_adaptive() {
        let config = TernaryConfig {
            target_sparsity: Some(0.5),
            normalize: false,
            ..Default::default()
        };
        let mut quantizer = TernaryQuantizer::new(4, config);
        let training: Vec<f32> = vec![
            0.1, 0.2, 0.3, 0.4, -0.1, -0.2, -0.3, -0.4, 0.5, 0.6, 0.7, 0.8, -0.5, -0.6, -0.7,
            -0.8,
        ];
        quantizer.fit(&training, 4).unwrap();
        assert!(quantizer.adaptive_thresholds.is_some());
    }

    #[test]
    fn test_batch_quantize() {
        let quantizer = TernaryQuantizer::with_dimension(4);
        let flat = [0.5, -0.5, 0.0, 0.0, -0.5, 0.5, 0.0, 0.0];
        let batch = quantizer.quantize_batch(&flat, 2).unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].dimension(), 4);
        assert_eq!(batch[1].dimension(), 4);
    }

    #[test]
    fn test_hamming_distance() {
        let quantizer = TernaryQuantizer::new(4, no_norm_config(0.3));
        let q1 = quantizer.quantize(&[0.5, -0.5, 0.0, 0.0]).unwrap();
        let q2 = quantizer.quantize(&[0.5, 0.5, 0.0, -0.5]).unwrap();
        // Components 1 and 3 decode differently.
        assert_eq!(ternary_hamming(&q1, &q2), 2);
    }

    #[test]
    fn test_memory_efficiency() {
        let quantizer = TernaryQuantizer::with_dimension(1024);
        let input: Vec<f32> = (0..1024).map(|i| (i as f32 / 1024.0) - 0.5).collect();
        let quantized = quantizer.quantize(&input).unwrap();
        // 1024 components * 2 bits = 256 bytes.
        assert_eq!(quantized.memory_bytes(), 256);
    }
}