Skip to main content

god_graph/tensor/decomposition/
tensor_ring.rs

//! Tensor Ring Decomposition
//!
//! Implements Tensor Ring (TR) decomposition:
//! W(i₁,...,iₙ) = Tr[G₁(i₁) × G₂(i₂) × ... × Gₙ(iₙ)]
//!
//! where Gₖ(iₖ) ∈ R^(rₖ₋₁×rₖ) is the iₖ-th slice of the k-th core tensor, rₖ are the
//! TR ranks, and the ring closure requires rₙ = r₀ (the trace sums over this shared index).
//! Currently only 2D tensors (matrices) are supported.
//!
//! ## Compression Ratio
//!
//! Original parameters: Π Iₖ
//! TR parameters: Σ rₖ₋₁ × rₖ × Iₖ
//! Compression ratio: Original / TR
13
14use crate::tensor::DenseTensor;
15use crate::tensor::TensorBase;
16use crate::tensor::TensorError;
17
18/// Tensor Ring decomposition result
19#[derive(Debug, Clone)]
20pub struct TensorRing {
21    /// Core tensors [G₁, G₂, ..., Gₙ]
22    pub cores: Vec<DenseTensor>,
23    /// TR ranks [r₀, r₁, ..., rₙ]
24    pub ranks: Vec<usize>,
25    /// Original tensor shape
26    pub original_shape: Vec<usize>,
27}
28
29impl TensorRing {
30    /// Create a new TensorRing decomposition
31    pub fn new(cores: Vec<DenseTensor>, ranks: Vec<usize>, original_shape: Vec<usize>) -> Self {
32        Self {
33            cores,
34            ranks,
35            original_shape,
36        }
37    }
38
39    /// Get the number of dimensions
40    pub fn ndim(&self) -> usize {
41        self.original_shape.len()
42    }
43
44    /// Get the compression ratio
45    pub fn compression_ratio(&self) -> f64 {
46        let original_params: usize = self.original_shape.iter().product();
47        let tr_params: usize = self
48            .cores
49            .iter()
50            .map(|c| c.shape().iter().product::<usize>())
51            .sum();
52
53        if tr_params == 0 {
54            return f64::MAX;
55        }
56        original_params as f64 / tr_params as f64
57    }
58
59    /// Reconstruct the original tensor from TR decomposition
60    pub fn reconstruct(&self) -> Result<DenseTensor, TensorError> {
61        tensor_ring_reconstruct(self)
62    }
63}
64
65/// Perform Tensor Ring decomposition on a tensor
66///
67/// # Arguments
68///
69/// * `tensor` - Input tensor to decompose
70/// * `ranks` - TR ranks [r₀, r₁, ..., rₙ] where n is the number of dimensions
71///
72/// # Returns
73///
74/// TensorRing decomposition result
75///
76/// # Algorithm
77///
78/// For a 2D weight matrix W ∈ R^(m×n), we treat it as a 2D tensor
79/// and decompose it into 2 core tensors with TR structure.
80pub fn tensor_ring_decompose(
81    tensor: &DenseTensor,
82    ranks: &[usize],
83) -> Result<TensorRing, TensorError> {
84    let shape = tensor.shape();
85    let ndim = shape.len();
86
87    if ranks.len() != ndim + 1 {
88        return Err(TensorError::DimensionMismatch {
89            expected: ranks.len(),
90            got: ndim + 1,
91        });
92    }
93
94    let mut cores = Vec::with_capacity(ndim);
95
96    if ndim == 2 {
97        // For 2D matrices: W ∈ R^(m×n)
98        // TR decomposition: G₁ ∈ R^(r₀×m×r₁), G₂ ∈ R^(r₁×n×r₀)
99        // Reconstruction: W(i,j) = Tr(G₁(:,i,:) × G₂(:,j,:))
100        let (m, n) = (shape[0], shape[1]);
101        let (r0, r1, r2) = (ranks[0], ranks[1], ranks[2]);
102        
103        // For TR, we need r0 == r2 (ring closure)
104        if r0 != r2 {
105            return Err(TensorError::ShapeMismatch {
106                expected: vec![r2],
107                got: vec![r0],
108            });
109        }
110        
111        // Use SVD-based initialization
112        let (u, s, v) = crate::tensor::decomposition::svd_decompose(tensor, Some(r1))?;
113        
114        let u_data = u.data();
115        let s_data = s.data();
116        let v_data = v.data();
117        
118        let k = r1; // truncated rank
119        
120        // Core 1: G₁ ∈ R^(r0 × m × r1)
121        // G₁(α, i, β) = U(i, β) * sqrt(S(β)) * δ(α, β) if r0 >= r1
122        // We use a simplified initialization: G₁(α, i, β) = U(i, β) * sqrt(S(β)) when α = β
123        let mut g1_data = vec![0.0; r0 * m * r1];
124        for alpha in 0..r0 {
125            for i in 0..m {
126                for beta in 0..r1 {
127                    if alpha == beta && alpha < k {
128                        g1_data[alpha * m * r1 + i * r1 + beta] = u_data[i * k + alpha] * s_data[alpha].sqrt();
129                    }
130                }
131            }
132        }
133        let g1 = DenseTensor::from_vec(g1_data, vec![r0, m, r1]);
134        
135        // Core 2: G₂ ∈ R^(r1 × n × r0)
136        // G₂(β, j, α) = V(j, β) * sqrt(S(β)) * δ(α, β)
137        let mut g2_data = vec![0.0; r1 * n * r0];
138        for beta in 0..r1 {
139            for j in 0..n {
140                for alpha in 0..r0 {
141                    if alpha == beta && beta < k {
142                        g2_data[beta * n * r0 + j * r0 + alpha] = v_data[j * k + beta] * s_data[beta].sqrt();
143                    }
144                }
145            }
146        }
147        let g2 = DenseTensor::from_vec(g2_data, vec![r1, n, r0]);
148        
149        cores.push(g1);
150        cores.push(g2);
151    } else {
152        return Err(TensorError::UnsupportedDType {
153            dtype: format!("ndim={}", ndim),
154            operation: "Tensor Ring decomposition for ndim > 2".to_string(),
155        });
156    }
157
158    Ok(TensorRing::new(cores, ranks.to_vec(), shape.to_vec()))
159}
160
161/// Reconstruct a tensor from its Tensor Ring decomposition
162///
163/// # Arguments
164///
165/// * `tr` - TensorRing decomposition to reconstruct
166///
167/// # Returns
168///
169/// Reconstructed tensor
170pub fn tensor_ring_reconstruct(tr: &TensorRing) -> Result<DenseTensor, TensorError> {
171    let ndim = tr.ndim();
172
173    if ndim == 2 && tr.cores.len() >= 2 {
174        // For 2D case: W(i,j) = Σ_{α,β} G₁(α,i,β) × G₂(β,j,α)
175        // This is the trace of the matrix product
176        let g1 = &tr.cores[0];
177        let g2 = &tr.cores[1];
178
179        let g1_shape = g1.shape();
180        let g2_shape = g2.shape();
181
182        let m = g1_shape[1]; // First dimension (from G1)
183        let n = g2_shape[1]; // Second dimension (from G2)
184        
185        let r0 = g1_shape[0]; // G1 first index
186        let r1 = g1_shape[2]; // G1 third index (should equal G2 first index)
187        
188        if r1 != g2_shape[0] {
189            return Err(TensorError::ShapeMismatch {
190                expected: vec![r1],
191                got: vec![g2_shape[0]],
192            });
193        }
194        
195        if r0 != g2_shape[2] {
196            return Err(TensorError::ShapeMismatch {
197                expected: vec![r0],
198                got: vec![g2_shape[2]],
199            });
200        }
201
202        let g1_data = g1.data();
203        let g2_data = g2.data();
204        let mut result = vec![0.0; m * n];
205
206        // Contract: W(i,j) = Σ_{α,β} G₁(α,i,β) × G₂(β,j,α)
207        for i in 0..m {
208            for j in 0..n {
209                let mut sum = 0.0;
210                for alpha in 0..r0 {
211                    for beta in 0..r1 {
212                        // G1(α, i, β)
213                        let g1_val = g1_data[alpha * m * r1 + i * r1 + beta];
214                        // G2(β, j, α)
215                        let g2_val = g2_data[beta * n * r0 + j * r0 + alpha];
216                        sum += g1_val * g2_val;
217                    }
218                }
219                result[i * n + j] = sum;
220            }
221        }
222
223        Ok(DenseTensor::from_vec(result, vec![m, n]))
224    } else {
225        Err(TensorError::UnsupportedDType {
226            dtype: format!("ndim={}", ndim),
227            operation: "Tensor Ring reconstruction".to_string(),
228        })
229    }
230}
231
232/// Compress a weight matrix using Tensor Ring decomposition
233///
234/// # Arguments
235///
236/// * `tensor` - Weight tensor to compress
237/// * `target_rank` - Target TR rank (controls compression vs accuracy)
238///
239/// # Returns
240///
241/// Compressed TensorRing representation
242pub fn compress_tensor_ring(
243    tensor: &DenseTensor,
244    target_rank: usize,
245) -> Result<TensorRing, TensorError> {
246    let shape = tensor.shape();
247
248    if shape.len() != 2 {
249        return Err(TensorError::DimensionMismatch {
250            expected: 2,
251            got: shape.len(),
252        });
253    }
254
255    // Use balanced ranks for simplicity: r0 = r1 = r2 = target_rank
256    let ranks = vec![target_rank, target_rank, target_rank];
257
258    tensor_ring_decompose(tensor, &ranks)
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between the flattened data of two equally-sized tensors.
    fn mse(expected: &DenseTensor, actual: &DenseTensor) -> f64 {
        let expected = expected.data();
        let actual = actual.data();
        assert_eq!(expected.len(), actual.len());
        expected
            .iter()
            .zip(actual.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / expected.len() as f64
    }

    #[test]
    fn test_tensor_ring_2d() {
        // [[1, 2], [3, 4], [5, 6], [7, 8]] has rank 2, so a rank-2 TR is exact.
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![4, 2],
        );

        let ranks = vec![2, 2, 2];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();

        assert_eq!(tr.cores.len(), 2);
        assert_eq!(tr.ranks, ranks);
        assert_eq!(tr.original_shape, vec![4, 2]);
        // G₁ ∈ R^(r₀×m×r₁), G₂ ∈ R^(r₁×n×r₀)
        assert_eq!(tr.cores[0].shape().to_vec(), vec![2, 4, 2]);
        assert_eq!(tr.cores[1].shape().to_vec(), vec![2, 2, 2]);

        // Original: 4 * 2 = 8; TR: 2*4*2 + 2*2*2 = 24.
        assert!((tr.compression_ratio() - 8.0 / 24.0).abs() < 1e-12);

        let err = mse(&tensor, &tr.reconstruct().unwrap());
        assert!(err < 1e-6, "MSE too high: {}", err);
    }

    #[test]
    fn test_tensor_ring_reconstruct() {
        // Rank-1 matrix: outer product of [1, 2, 3, 4] and [1, 1].
        let tensor = DenseTensor::from_vec(
            vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0],
            vec![4, 2],
        );

        let ranks = vec![2, 2, 2];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();
        let reconstructed = tr.reconstruct().unwrap();

        assert_eq!(reconstructed.shape(), tensor.shape());

        let err = mse(&tensor, &reconstructed);
        assert!(err < 1e-6, "MSE too high: {}", err);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = DenseTensor::from_vec(vec![1.0; 64 * 64], vec![64, 64]);

        let tr = compress_tensor_ring(&tensor, 8).unwrap();

        // Original: 64 * 64 = 4096.
        // TR params: 8*64*8 + 8*64*8 = 8192, so the ratio is exactly 0.5 (no compression).
        assert!((tr.compression_ratio() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_tensor_ring_rank1() {
        // Pure rank-1 matrix: [[2, 4], [3, 6]] = [2, 3]ᵀ · [1, 2].
        let tensor = DenseTensor::from_vec(vec![2.0, 4.0, 3.0, 6.0], vec![2, 2]);

        let ranks = vec![1, 1, 1];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();
        let reconstructed = tr.reconstruct().unwrap();

        let orig_data = tensor.data();
        let recon_data = reconstructed.data();

        for (a, b) in orig_data.iter().zip(recon_data.iter()) {
            assert!((a - b).abs() < 1e-4, "Mismatch: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_rank_length_mismatch_is_error() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let result = tensor_ring_decompose(&tensor, &[1, 1]);
        assert!(matches!(result, Err(TensorError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_ring_closure_violation_is_error() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        // r₀ = 1 but r₂ = 3: the ring does not close.
        let result = tensor_ring_decompose(&tensor, &[1, 2, 3]);
        assert!(matches!(result, Err(TensorError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_non_2d_tensor_is_rejected() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);

        assert!(matches!(
            tensor_ring_decompose(&tensor, &[1, 1]),
            Err(TensorError::UnsupportedDType { .. })
        ));
        assert!(matches!(
            compress_tensor_ring(&tensor, 1),
            Err(TensorError::DimensionMismatch {
                expected: 2,
                got: 1
            })
        ));
    }
}