use std::time::Instant;
use super::hadamard::{
hadamard_batch_inverse, hadamard_batch_transform, log2_exact, next_power_of_2,
pad_to_power_of_2, HadamardTransform,
};
use crate::error::{Result, RuvLLMError};
/// Events recorded by `IncoherenceTransform` while `emit_events` is enabled in its
/// configuration. They are retrievable through `events()` / `take_events()`.
#[derive(Debug, Clone)]
pub enum IncoherenceEvent {
    /// Pushed at the end of a successful `apply_before_quantization` call.
    IncoherenceApplied {
        /// Buffer length after padding (same value as `padded_dim`).
        num_elements: usize,
        /// Wall-clock time spent in the call, in microseconds.
        duration_us: u64,
        /// `true` when the input length was not already the padded length.
        required_padding: bool,
        /// Buffer length before padding.
        original_dim: usize,
        /// Buffer length after padding.
        padded_dim: usize,
        /// Largest absolute value before the transform; `0.0` when
        /// `compute_stats` is disabled.
        max_before: f32,
        /// Largest absolute value after the transform; `0.0` when
        /// `compute_stats` is disabled.
        max_after: f32,
    },
    /// Pushed at the end of a successful `restore_after_dequantization` call.
    IncoherenceRestored {
        /// Length the buffer was restored to.
        num_elements: usize,
        /// Wall-clock time spent in the call, in microseconds.
        duration_us: u64,
        /// Always `None` in the code visible in this file.
        reconstruction_error: Option<f32>,
    },
    /// Pushed when a forward or inverse pass fails an internal length check.
    IncoherenceError {
        /// Human-readable description of the failure.
        message: String,
        /// Stage in which the failure was detected.
        phase: IncoherencePhase,
    },
}
/// Stage of processing an `IncoherenceEvent::IncoherenceError` refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncoherencePhase {
    /// Reported by `apply_before_quantization`.
    Forward,
    /// Reported by `restore_after_dequantization`.
    Inverse,
    /// Not emitted by any code visible in this file.
    // NOTE(review): presumably reserved for construction-time failures — confirm.
    Init,
}
/// Settings controlling how `IncoherenceTransform` builds and applies its transforms.
#[derive(Debug, Clone)]
pub struct IncoherenceConfig {
    /// Seed passed to `HadamardTransform::randomized`; `42` is used when this is
    /// `None` and `randomized` is `true`.
    pub seed: Option<u64>,
    /// Selects `HadamardTransform::randomized` (`true`) or
    /// `HadamardTransform::deterministic` (`false`).
    pub randomized: bool,
    /// When `true`, the forward pass measures the largest absolute value before
    /// and after the transform.
    pub compute_stats: bool,
    /// When `true`, `IncoherenceEvent`s are recorded.
    pub emit_events: bool,
    /// Inputs shorter than this are returned unchanged by
    /// `apply_before_quantization`.
    pub min_dimension: usize,
    /// Not read by any code visible in this file.
    // NOTE(review): presumably consumed by callers to choose the batch API — confirm.
    pub batch_mode: bool,
}
impl Default for IncoherenceConfig {
fn default() -> Self {
Self {
seed: Some(42), randomized: true,
compute_stats: true,
emit_events: true,
min_dimension: 16, batch_mode: true,
}
}
}
impl IncoherenceConfig {
    /// Preset with a randomized transform (seed `12345`), statistics and events
    /// enabled, and a minimum dimension of 8.
    pub fn quality() -> Self {
        let enabled = true;
        Self {
            seed: Some(12345),
            randomized: enabled,
            compute_stats: enabled,
            emit_events: enabled,
            min_dimension: 8,
            batch_mode: true,
        }
    }

    /// Preset with a deterministic transform, statistics and events disabled,
    /// and a minimum dimension of 32.
    pub fn performance() -> Self {
        let disabled = false;
        Self {
            seed: None,
            randomized: disabled,
            compute_stats: disabled,
            emit_events: disabled,
            min_dimension: 32,
            batch_mode: true,
        }
    }

    /// Returns the configuration with `seed` replaced by `Some(seed)`.
    ///
    /// The `randomized` flag is left as it was.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Returns the configuration with the `randomized` flag replaced.
    pub fn with_randomized(self, randomized: bool) -> Self {
        Self { randomized, ..self }
    }
}
/// Counters accumulated by `IncoherenceTransform` across calls.
#[derive(Debug, Clone, Default)]
pub struct IncoherenceStats {
    /// Forward passes performed: one per transforming
    /// `apply_before_quantization` call plus `batch_size` per `apply_batch` call.
    pub forward_count: usize,
    /// Inverse passes performed: one per `restore_after_dequantization` call
    /// plus `batch_size` per `restore_batch` call.
    pub inverse_count: usize,
    /// Total number of elements fed through forward passes (padded lengths).
    pub total_elements: u64,
    /// Microseconds spent in `apply_before_quantization` (batch calls are not timed).
    pub forward_time_us: u64,
    /// Microseconds spent in `restore_after_dequantization` (batch calls are not timed).
    pub inverse_time_us: u64,
    /// Running mean of `max_before / max_after`, updated only when both maxima
    /// are positive.
    pub avg_outlier_reduction: f32,
    /// Number of forward passes that had to pad the input.
    pub padded_count: usize,
}
/// Stateful helper that applies a Hadamard transform to a buffer before
/// quantization and inverts it after dequantization.
///
/// Transforms are created lazily, one per `log2(dimension)`, and cached.
/// Statistics are always accumulated; events are recorded only when
/// `IncoherenceConfig::emit_events` is set.
pub struct IncoherenceTransform {
    /// Settings supplied at construction time.
    config: IncoherenceConfig,
    /// Cache of transforms keyed by `log2` of the dimension they operate on.
    transforms: std::collections::HashMap<u32, HadamardTransform>,
    /// Counters exposed through `stats()`.
    stats: IncoherenceStats,
    /// Events recorded since the last `take_events()`.
    events: Vec<IncoherenceEvent>,
    /// Pre-padding lengths recorded by the forward pass, keyed by the address
    /// of the buffer's storage (`Vec::as_ptr`). An entry is only meaningful
    /// while the very same allocation is passed back to the inverse pass.
    pending_original_dims: std::collections::HashMap<usize, usize>,
    /// Number of forward passes that contributed a sample to
    /// `stats.avg_outlier_reduction`. Kept separately because
    /// `stats.forward_count` also advances for passes without a sample.
    outlier_samples: usize,
}

impl IncoherenceTransform {
    /// Creates a transform helper with an empty cache, zeroed statistics and
    /// no recorded events.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new(config: IncoherenceConfig) -> Result<Self> {
        Ok(Self {
            config,
            transforms: std::collections::HashMap::new(),
            stats: IncoherenceStats::default(),
            events: Vec::new(),
            pending_original_dims: std::collections::HashMap::new(),
            outlier_samples: 0,
        })
    }

    /// Creates a transform helper using `IncoherenceConfig::default()`.
    pub fn with_defaults() -> Result<Self> {
        Self::new(IncoherenceConfig::default())
    }

    /// Returns the cached transform for `log_dim`, building and caching it on
    /// first use.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the `HadamardTransform` constructor.
    fn get_or_create_transform(&mut self, log_dim: u32) -> Result<&HadamardTransform> {
        use std::collections::hash_map::Entry;

        let transform: &HadamardTransform = match self.transforms.entry(log_dim) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => {
                let created = if self.config.randomized {
                    // A randomized transform without an explicit seed falls back to 42.
                    HadamardTransform::randomized(log_dim, self.config.seed.unwrap_or(42))?
                } else {
                    HadamardTransform::deterministic(log_dim)?
                };
                slot.insert(created)
            }
        };
        Ok(transform)
    }

    /// Largest absolute value in `values`, or `0.0` for an empty slice.
    fn max_abs(values: &[f32]) -> f32 {
        values.iter().map(|x| x.abs()).fold(0.0f32, |a, b| a.max(b))
    }

    /// Folds one `max_before / max_after` sample into the running mean.
    fn record_outlier_reduction(&mut self, reduction: f32) {
        self.outlier_samples += 1;
        let n = self.outlier_samples as f32;
        self.stats.avg_outlier_reduction =
            (self.stats.avg_outlier_reduction * (n - 1.0) + reduction) / n;
    }

    /// Checks the arguments of a batch call and returns `log2(dim)`.
    ///
    /// `not_power_of_2_message` is the error text used when `dim` is not a
    /// power of two.
    fn validate_batch(
        data_len: usize,
        dim: usize,
        batch_size: usize,
        not_power_of_2_message: &str,
    ) -> Result<u32> {
        // `checked_mul` makes an overflowing product a plain length mismatch
        // instead of a debug panic or a silently wrapped comparison.
        if dim.checked_mul(batch_size) != Some(data_len) {
            return Err(RuvLLMError::Quantization(format!(
                "Data length {} doesn't match dim {} * batch_size {}",
                data_len, dim, batch_size
            )));
        }
        log2_exact(dim)
            .ok_or_else(|| RuvLLMError::Quantization(not_power_of_2_message.to_string()))
    }

    /// Pads `data` with zeros to the next power-of-two length and applies the
    /// forward transform in place.
    ///
    /// Inputs shorter than `config.min_dimension` are left untouched. Returns
    /// the resulting length of `data` (the padded length, or the original
    /// length when the input was skipped).
    ///
    /// # Errors
    ///
    /// Returns an error if the padded length is not a power of two or if the
    /// transform cannot be created. `data` is not modified in either case.
    pub fn apply_before_quantization(&mut self, data: &mut Vec<f32>) -> Result<usize> {
        let start = Instant::now();
        let original_len = data.len();
        if original_len < self.config.min_dimension {
            return Ok(original_len);
        }

        let compute_stats = self.config.compute_stats;
        let max_before = if compute_stats {
            Self::max_abs(data.as_slice())
        } else {
            0.0
        };

        let target_len = next_power_of_2(original_len);
        let required_padding = target_len != original_len;
        let log_dim = match log2_exact(target_len) {
            Some(ld) => ld,
            None => {
                self.emit_error(
                    "Internal error: padded length not power of 2",
                    IncoherencePhase::Forward,
                );
                return Err(RuvLLMError::Quantization(
                    "Padded length is not a power of 2".to_string(),
                ));
            }
        };

        // Obtain the transform before padding so that a failure leaves the
        // caller's buffer exactly as it was passed in.
        let transform = self.get_or_create_transform(log_dim)?;
        if required_padding {
            data.resize(target_len, 0.0);
        }
        transform.forward_inplace(data);

        let max_after = if compute_stats {
            Self::max_abs(data.as_slice())
        } else {
            0.0
        };

        // Remember the pre-padding length under the buffer's (post-resize)
        // address. When no padding happened there is nothing to remember, but
        // a stale entry left by an earlier buffer at this address is dropped.
        let data_id = data.as_ptr() as usize;
        if required_padding {
            self.pending_original_dims.insert(data_id, original_len);
        } else {
            self.pending_original_dims.remove(&data_id);
        }

        let duration_us = start.elapsed().as_micros() as u64;
        self.stats.forward_count += 1;
        self.stats.total_elements += target_len as u64;
        self.stats.forward_time_us += duration_us;
        if required_padding {
            self.stats.padded_count += 1;
        }
        if max_before > 0.0 && max_after > 0.0 {
            self.record_outlier_reduction(max_before / max_after);
        }

        if self.config.emit_events {
            self.events.push(IncoherenceEvent::IncoherenceApplied {
                num_elements: target_len,
                duration_us,
                required_padding,
                original_dim: original_len,
                padded_dim: target_len,
                max_before,
                max_after,
            });
        }
        Ok(target_len)
    }

    /// Applies the inverse transform to `data` in place and truncates it back
    /// to its pre-padding length.
    ///
    /// The final length is `original_len` when given. Otherwise it is the
    /// length recorded by the forward pass for this exact buffer, if any, and
    /// failing that the current length. Buffers shorter than
    /// `config.min_dimension` are left untouched, mirroring the forward pass,
    /// which never transforms them.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len()` is not a power of two (and at least
    /// `config.min_dimension`) or if the transform cannot be created.
    pub fn restore_after_dequantization(
        &mut self,
        data: &mut Vec<f32>,
        original_len: Option<usize>,
    ) -> Result<()> {
        let start = Instant::now();
        let current_len = data.len();

        // Always consume the bookkeeping entry for this buffer, even when the
        // caller supplies the length explicitly; otherwise the map grows
        // without bound and stale entries could match unrelated buffers.
        let data_id = data.as_ptr() as usize;
        let recorded_len = self.pending_original_dims.remove(&data_id);

        if current_len < self.config.min_dimension {
            return Ok(());
        }

        let log_dim = match log2_exact(current_len) {
            Some(ld) => ld,
            None => {
                self.emit_error("Data length is not a power of 2", IncoherencePhase::Inverse);
                return Err(RuvLLMError::Quantization(
                    "Data length must be a power of 2 for inverse transform".to_string(),
                ));
            }
        };
        let transform = self.get_or_create_transform(log_dim)?;
        transform.inverse_inplace(data);

        // Trust a recorded length only if it is consistent with this buffer:
        // it must pad up to exactly the current length.
        let inferred_len = recorded_len
            .filter(|&len| len <= current_len && next_power_of_2(len) == current_len);
        let final_len = original_len
            .or(inferred_len)
            .unwrap_or(current_len)
            .min(current_len);
        data.truncate(final_len);

        let duration_us = start.elapsed().as_micros() as u64;
        self.stats.inverse_count += 1;
        self.stats.inverse_time_us += duration_us;
        if self.config.emit_events {
            self.events.push(IncoherenceEvent::IncoherenceRestored {
                num_elements: final_len,
                duration_us,
                reconstruction_error: None,
            });
        }
        Ok(())
    }

    /// Applies the forward transform to `batch_size` consecutive rows of
    /// length `dim` stored in `data`.
    ///
    /// No padding is performed and `config.min_dimension` is not consulted.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != dim * batch_size`, if `dim` is not a
    /// power of two, or if the underlying batch transform fails.
    pub fn apply_batch(&mut self, data: &mut [f32], dim: usize, batch_size: usize) -> Result<()> {
        let total_len = data.len();
        let log_dim = Self::validate_batch(
            total_len,
            dim,
            batch_size,
            "Dimension must be a power of 2 for batch transform",
        )?;
        let transform = self.get_or_create_transform(log_dim)?;
        hadamard_batch_transform(transform, data, batch_size)?;
        self.stats.forward_count += batch_size;
        self.stats.total_elements += total_len as u64;
        Ok(())
    }

    /// Applies the inverse transform to `batch_size` consecutive rows of
    /// length `dim` stored in `data`.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != dim * batch_size`, if `dim` is not a
    /// power of two, or if the underlying batch inverse fails.
    pub fn restore_batch(&mut self, data: &mut [f32], dim: usize, batch_size: usize) -> Result<()> {
        let log_dim = Self::validate_batch(
            data.len(),
            dim,
            batch_size,
            "Dimension must be a power of 2 for batch inverse",
        )?;
        let transform = self.get_or_create_transform(log_dim)?;
        hadamard_batch_inverse(transform, data, batch_size)?;
        self.stats.inverse_count += batch_size;
        Ok(())
    }

    /// Statistics accumulated since construction or the last `reset_stats()`.
    pub fn stats(&self) -> &IncoherenceStats {
        &self.stats
    }

    /// Removes and returns all recorded events.
    pub fn take_events(&mut self) -> Vec<IncoherenceEvent> {
        std::mem::take(&mut self.events)
    }

    /// Recorded events, without removing them.
    pub fn events(&self) -> &[IncoherenceEvent] {
        &self.events
    }

    /// The configuration this helper was built with.
    pub fn config(&self) -> &IncoherenceConfig {
        &self.config
    }

    /// Resets all statistics, including the running outlier-reduction mean.
    pub fn reset_stats(&mut self) {
        self.stats = IncoherenceStats::default();
        self.outlier_samples = 0;
    }

    /// Drops all cached transforms; they are rebuilt on demand.
    pub fn clear_cache(&mut self) {
        self.transforms.clear();
    }

    /// Records an error event when event emission is enabled.
    fn emit_error(&mut self, message: &str, phase: IncoherencePhase) {
        if self.config.emit_events {
            self.events.push(IncoherenceEvent::IncoherenceError {
                message: message.to_string(),
                phase,
            });
        }
    }

    /// Runs `HadamardTransform::verify_orthogonality` on the transform for
    /// dimension `dim` and returns its result.
    ///
    /// # Errors
    ///
    /// Returns an error if `dim` is not a power of two or if the transform
    /// cannot be created.
    pub fn verify(&mut self, dim: usize, tolerance: f32) -> Result<bool> {
        let log_dim = match log2_exact(dim) {
            Some(ld) => ld,
            None => {
                return Err(RuvLLMError::Quantization(
                    "Dimension must be a power of 2 for verification".to_string(),
                ));
            }
        };
        let transform = self.get_or_create_transform(log_dim)?;
        Ok(transform.verify_orthogonality(tolerance))
    }
}
/// One-shot forward pass using a throwaway `IncoherenceTransform`.
///
/// The transform is randomized exactly when `seed` is `Some`; statistics and
/// events are disabled and inputs shorter than 8 elements are left untouched.
/// Returns the resulting length of `data`.
///
/// # Errors
///
/// Propagates any error from `IncoherenceTransform::apply_before_quantization`.
pub fn apply_incoherence(data: &mut Vec<f32>, seed: Option<u64>) -> Result<usize> {
    let randomized = seed.is_some();
    IncoherenceTransform::new(IncoherenceConfig {
        seed,
        randomized,
        compute_stats: false,
        emit_events: false,
        min_dimension: 8,
        batch_mode: false,
    })?
    .apply_before_quantization(data)
}
/// One-shot inverse pass matching `apply_incoherence`.
///
/// `seed` must be the value given to `apply_incoherence`. After the inverse
/// transform, `data` is truncated to `original_len` if it is longer. Buffers
/// shorter than 8 elements are returned unchanged, because
/// `apply_incoherence` never transforms them.
///
/// # Errors
///
/// Propagates any error from
/// `IncoherenceTransform::restore_after_dequantization`, e.g. when
/// `data.len()` is not a power of two.
pub fn restore_incoherence(
    data: &mut Vec<f32>,
    original_len: usize,
    seed: Option<u64>,
) -> Result<()> {
    // Must equal the `min_dimension` used by `apply_incoherence`.
    const MIN_DIMENSION: usize = 8;

    // The forward helper leaves such buffers untouched; inverting them here
    // would corrupt (or reject) data that was never transformed.
    if data.len() < MIN_DIMENSION {
        return Ok(());
    }

    let config = IncoherenceConfig {
        seed,
        randomized: seed.is_some(),
        compute_stats: false,
        emit_events: false,
        min_dimension: MIN_DIMENSION,
        batch_mode: false,
    };
    let mut transform = IncoherenceTransform::new(config)?;
    transform.restore_after_dequantization(data, Some(original_len))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with only `min_dimension` overridden.
    fn transform_with_min_dimension(min_dimension: usize) -> IncoherenceTransform {
        let config = IncoherenceConfig {
            min_dimension,
            ..Default::default()
        };
        IncoherenceTransform::new(config).unwrap()
    }

    /// Largest absolute value in `values`.
    fn max_abs(values: &[f32]) -> f32 {
        values.iter().map(|x| x.abs()).fold(0.0f32, |a, b| a.max(b))
    }

    /// Asserts that both slices have equal length and agree element-wise
    /// within `tolerance`.
    fn assert_close(actual: &[f32], expected: &[f32], tolerance: f32, context: &str) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", context);
        for (a, b) in actual.iter().zip(expected.iter()) {
            assert!((a - b).abs() < tolerance, "{}: {} vs {}", context, a, b);
        }
    }

    #[test]
    fn test_incoherence_basic() {
        let mut transform = transform_with_min_dimension(4);
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut data = original.clone();
        let padded_dim = transform.apply_before_quantization(&mut data).unwrap();
        assert_eq!(padded_dim, 8);
        transform
            .restore_after_dequantization(&mut data, Some(8))
            .unwrap();
        assert_close(&data, &original, 1e-5, "Roundtrip failed");
    }

    #[test]
    fn test_incoherence_with_padding() {
        let mut transform = transform_with_min_dimension(4);
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original_len = original.len();
        let mut data = original.clone();
        let padded_dim = transform.apply_before_quantization(&mut data).unwrap();
        assert_eq!(padded_dim, 8);
        assert_eq!(data.len(), 8);
        transform
            .restore_after_dequantization(&mut data, Some(original_len))
            .unwrap();
        assert_eq!(data.len(), original_len);
        assert_close(&data, &original, 1e-5, "Padded roundtrip failed");
    }

    #[test]
    fn test_incoherence_infers_original_length() {
        // Without an explicit length, the inverse pass must fall back to the
        // length recorded by the forward pass for the same buffer.
        let mut transform = transform_with_min_dimension(4);
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut data = original.clone();
        transform.apply_before_quantization(&mut data).unwrap();
        assert_eq!(data.len(), 8);
        transform
            .restore_after_dequantization(&mut data, None)
            .unwrap();
        assert_close(&data, &original, 1e-5, "Inferred-length roundtrip failed");
    }

    #[test]
    fn test_outlier_spreading() {
        let config = IncoherenceConfig {
            seed: Some(42),
            randomized: true,
            compute_stats: true,
            emit_events: true,
            min_dimension: 4,
            batch_mode: false,
        };
        let mut transform = IncoherenceTransform::new(config).unwrap();
        let mut data: Vec<f32> = vec![1.0, 1.0, 1.0, 100.0, 1.0, 1.0, 1.0, 1.0];
        let max_before = max_abs(&data);
        transform.apply_before_quantization(&mut data).unwrap();
        let max_after = max_abs(&data);
        assert!(
            max_after < max_before * 0.9,
            "Outlier not spread: before={}, after={}",
            max_before,
            max_after
        );

        let events = transform.take_events();
        assert!(!events.is_empty());
        // Fail loudly if the first event is not the forward-pass event rather
        // than silently skipping the check.
        match &events[0] {
            IncoherenceEvent::IncoherenceApplied {
                max_before: mb,
                max_after: ma,
                ..
            } => assert!((*ma) < (*mb) * 0.9),
            other => panic!("expected IncoherenceApplied, got {:?}", other),
        }
    }

    #[test]
    fn test_batch_transform() {
        let mut transform = IncoherenceTransform::with_defaults().unwrap();
        let dim = 16;
        let batch_size = 4;
        let original: Vec<f32> = (0..dim * batch_size).map(|i| i as f32).collect();
        let mut data = original.clone();
        transform.apply_batch(&mut data, dim, batch_size).unwrap();
        transform.restore_batch(&mut data, dim, batch_size).unwrap();
        assert_close(&data, &original, 1e-4, "Batch roundtrip failed");
    }

    #[test]
    fn test_batch_rejects_bad_arguments() {
        let mut transform = IncoherenceTransform::with_defaults().unwrap();

        // Length does not equal dim * batch_size.
        let mut short = vec![0.0f32; 10];
        assert!(transform.apply_batch(&mut short, 16, 4).is_err());
        assert!(transform.restore_batch(&mut short, 16, 4).is_err());

        // Dimension is not a power of two.
        let mut odd = vec![0.0f32; 12];
        assert!(transform.apply_batch(&mut odd, 6, 2).is_err());
        assert!(transform.restore_batch(&mut odd, 6, 2).is_err());
    }

    #[test]
    fn test_verify() {
        let mut transform = IncoherenceTransform::with_defaults().unwrap();
        assert!(transform.verify(64, 1e-5).unwrap());
    }

    #[test]
    fn test_statistics() {
        let config = IncoherenceConfig {
            seed: Some(42),
            randomized: true,
            compute_stats: true,
            emit_events: true,
            min_dimension: 4,
            batch_mode: false,
        };
        let mut transform = IncoherenceTransform::new(config).unwrap();
        let mut data = vec![1.0, 2.0, 3.0, 4.0];
        transform.apply_before_quantization(&mut data).unwrap();
        let stats = transform.stats();
        assert_eq!(stats.forward_count, 1);
        assert_eq!(stats.total_elements, 4);
        // The input is already a power of two and nothing was restored yet.
        assert_eq!(stats.padded_count, 0);
        assert_eq!(stats.inverse_count, 0);
    }

    #[test]
    fn test_skip_small_tensors() {
        let mut transform = transform_with_min_dimension(32);
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let mut data = original.clone();
        let padded_dim = transform.apply_before_quantization(&mut data).unwrap();
        assert_eq!(padded_dim, 4);
        assert_eq!(data, original);
    }

    #[test]
    fn test_config_builders() {
        let quality = IncoherenceConfig::quality();
        assert!(quality.randomized);
        assert!(quality.compute_stats);
        let perf = IncoherenceConfig::performance();
        assert!(!perf.randomized);
        assert!(!perf.compute_stats);
    }

    #[test]
    fn test_convenience_functions() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let original_len = original.len();
        let mut data = original.clone();
        let _padded = apply_incoherence(&mut data, Some(12345)).unwrap();
        restore_incoherence(&mut data, original_len, Some(12345)).unwrap();
        assert_close(&data, &original, 1e-5, "Convenience roundtrip failed");
    }

    #[test]
    fn test_energy_preservation_through_pipeline() {
        let mut transform = IncoherenceTransform::with_defaults().unwrap();
        let original: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let original_energy: f32 = original.iter().map(|x| x * x).sum();
        let mut data = original.clone();
        transform.apply_before_quantization(&mut data).unwrap();
        let transformed_energy: f32 = data.iter().map(|x| x * x).sum();
        let relative_error = (original_energy - transformed_energy).abs() / original_energy;
        assert!(
            relative_error < 0.01,
            "Energy not preserved: original={}, transformed={}, error={}",
            original_energy,
            transformed_energy,
            relative_error
        );
    }
}