use vyre_primitives::math::tensor_train_decompose::cpu_ref;
/// Tensor-train (TT) form of a cost tensor, as produced by
/// [`compress_cost_tensor`].
///
/// The struct is a plain data carrier: it performs no validation of the
/// relationship between `cores`, `dims` and `ranks`.
#[derive(Debug, Clone)]
pub struct CompressedCostTensor {
    /// TT cores, one flat `f64` buffer per core, exactly as returned by the
    /// decomposition routine. The element layout inside each buffer is
    /// defined by that routine, not by this type.
    pub cores: Vec<Vec<f64>>,
    /// Mode sizes of the original (dense) tensor.
    pub dims: Vec<u32>,
    /// TT ranks associated with the decomposition. When built through
    /// [`compress_cost_tensor`] this is a copy of the requested
    /// `target_ranks`.
    pub ranks: Vec<u32>,
}
/// Compresses a dense cost tensor into tensor-train form.
///
/// Increments the `tensor_train_compression_calls` observability counter,
/// then delegates the decomposition to `cpu_ref` and packages its cores
/// together with copies of `dims` and `target_ranks`.
///
/// * `tensor` - dense tensor data, passed through to `cpu_ref` unchanged.
/// * `dims` - mode sizes of `tensor`.
/// * `target_ranks` - requested TT ranks; stored verbatim in the result.
///
/// NOTE(review): no input validation happens here; how `cpu_ref` reacts to
/// inconsistent `tensor` / `dims` / `target_ranks` is not visible from this
/// function -- confirm against its documentation.
#[must_use]
pub fn compress_cost_tensor(
    tensor: &[f64],
    dims: &[u32],
    target_ranks: &[u32],
) -> CompressedCostTensor {
    use crate::observability::{bump, tensor_train_compression_calls};

    // Count the call before doing any work, so the counter is bumped even if
    // the decomposition below panics.
    bump(&tensor_train_compression_calls);

    CompressedCostTensor {
        cores: cpu_ref(tensor, dims, target_ranks),
        dims: Vec::from(dims),
        ranks: Vec::from(target_ranks),
    }
}
/// Returns the fraction of storage saved by the tensor-train form.
///
/// The value is `1 - tt_size / original_size`, where `tt_size` is the total
/// number of elements across all cores and `original_size` is the product of
/// `dims`. A negative result means the TT form is larger than the dense
/// tensor.
///
/// Returns `0.0` when `dims` is empty or when any mode size is zero, since
/// there is no dense tensor to compare against.
///
/// The dense element count is accumulated with saturating multiplication, so
/// mode sizes whose product exceeds `usize::MAX` yield a ratio approaching
/// `1.0` instead of panicking (debug) or wrapping to a meaningless value
/// (release).
#[must_use]
pub fn compression_ratio(compressed: &CompressedCostTensor) -> f64 {
    // An empty product would be 1; treat "no modes" as "no tensor" instead.
    if compressed.dims.is_empty() {
        return 0.0;
    }

    // Saturate rather than overflow: exact for every representable product,
    // and a zero mode size still collapses the result to zero.
    let original_size = compressed
        .dims
        .iter()
        .fold(1usize, |acc, &d| acc.saturating_mul(d as usize));
    if original_size == 0 {
        return 0.0;
    }

    let tt_size: usize = compressed.cores.iter().map(Vec::len).sum();
    1.0 - (tt_size as f64) / (original_size as f64)
}
/// Returns the total number of `f64` elements stored across all TT cores.
///
/// This is an element count, not a byte count. An instance with no cores
/// yields `0`.
#[must_use]
pub fn tt_storage_size(compressed: &CompressedCostTensor) -> usize {
    let mut total = 0usize;
    for core in &compressed.cores {
        total += core.len();
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compressing a 2x3x2 tensor yields three cores and keeps the dims.
    #[test]
    fn compresses_3_mode_tensor() {
        let mode_sizes = vec![2u32, 3, 2];
        let requested_ranks = vec![1u32, 2, 2, 1];
        let dense: Vec<f64> = (0..12u32).map(f64::from).collect();

        let compressed = compress_cost_tensor(&dense, &mode_sizes, &requested_ranks);

        assert_eq!(compressed.cores.len(), 3);
        assert_eq!(compressed.dims, mode_sizes);
    }

    /// The ratio for a 4x4 tensor stays within `[-1, 1]`; it may be negative
    /// when the TT form stores more elements than the dense tensor.
    #[test]
    fn compression_ratio_is_in_unit_interval() {
        let mode_sizes = vec![4u32, 4];
        let requested_ranks = vec![1u32, 2, 1];
        let dense = vec![1.0; 16];

        let ratio = compression_ratio(&compress_cost_tensor(
            &dense,
            &mode_sizes,
            &requested_ranks,
        ));

        assert!(
            (-1.0..=1.0).contains(&ratio),
            "ratio out of expected range: {ratio}"
        );
    }

    /// Storage size is the sum of the individual core lengths (4 + 8 + 4).
    #[test]
    fn tt_storage_size_returns_sum() {
        let hand_built = CompressedCostTensor {
            cores: vec![vec![1.0; 4], vec![1.0; 8], vec![1.0; 4]],
            dims: vec![2, 4, 2],
            ranks: vec![1, 2, 2, 1],
        };

        assert_eq!(tt_storage_size(&hand_built), 16);
    }

    /// With no dims and no cores, both helpers report zero.
    #[test]
    fn empty_dims_handled() {
        let hand_built = CompressedCostTensor {
            cores: Vec::new(),
            dims: Vec::new(),
            ranks: vec![1],
        };

        assert_eq!(tt_storage_size(&hand_built), 0);
        assert_eq!(compression_ratio(&hand_built), 0.0);
    }
}