god_graph/tensor/decomposition/
tensor_ring.rs1use crate::tensor::DenseTensor;
15use crate::tensor::TensorBase;
16use crate::tensor::TensorError;
17
18#[derive(Debug, Clone)]
20pub struct TensorRing {
21 pub cores: Vec<DenseTensor>,
23 pub ranks: Vec<usize>,
25 pub original_shape: Vec<usize>,
27}
28
29impl TensorRing {
30 pub fn new(cores: Vec<DenseTensor>, ranks: Vec<usize>, original_shape: Vec<usize>) -> Self {
32 Self {
33 cores,
34 ranks,
35 original_shape,
36 }
37 }
38
39 pub fn ndim(&self) -> usize {
41 self.original_shape.len()
42 }
43
44 pub fn compression_ratio(&self) -> f64 {
46 let original_params: usize = self.original_shape.iter().product();
47 let tr_params: usize = self
48 .cores
49 .iter()
50 .map(|c| c.shape().iter().product::<usize>())
51 .sum();
52
53 if tr_params == 0 {
54 return f64::MAX;
55 }
56 original_params as f64 / tr_params as f64
57 }
58
59 pub fn reconstruct(&self) -> Result<DenseTensor, TensorError> {
61 tensor_ring_reconstruct(self)
62 }
63}
64
65pub fn tensor_ring_decompose(
81 tensor: &DenseTensor,
82 ranks: &[usize],
83) -> Result<TensorRing, TensorError> {
84 let shape = tensor.shape();
85 let ndim = shape.len();
86
87 if ranks.len() != ndim + 1 {
88 return Err(TensorError::DimensionMismatch {
89 expected: ranks.len(),
90 got: ndim + 1,
91 });
92 }
93
94 let mut cores = Vec::with_capacity(ndim);
95
96 if ndim == 2 {
97 let (m, n) = (shape[0], shape[1]);
101 let (r0, r1, r2) = (ranks[0], ranks[1], ranks[2]);
102
103 if r0 != r2 {
105 return Err(TensorError::ShapeMismatch {
106 expected: vec![r2],
107 got: vec![r0],
108 });
109 }
110
111 let (u, s, v) = crate::tensor::decomposition::svd_decompose(tensor, Some(r1))?;
113
114 let u_data = u.data();
115 let s_data = s.data();
116 let v_data = v.data();
117
118 let k = r1; let mut g1_data = vec![0.0; r0 * m * r1];
124 for alpha in 0..r0 {
125 for i in 0..m {
126 for beta in 0..r1 {
127 if alpha == beta && alpha < k {
128 g1_data[alpha * m * r1 + i * r1 + beta] = u_data[i * k + alpha] * s_data[alpha].sqrt();
129 }
130 }
131 }
132 }
133 let g1 = DenseTensor::from_vec(g1_data, vec![r0, m, r1]);
134
135 let mut g2_data = vec![0.0; r1 * n * r0];
138 for beta in 0..r1 {
139 for j in 0..n {
140 for alpha in 0..r0 {
141 if alpha == beta && beta < k {
142 g2_data[beta * n * r0 + j * r0 + alpha] = v_data[j * k + beta] * s_data[beta].sqrt();
143 }
144 }
145 }
146 }
147 let g2 = DenseTensor::from_vec(g2_data, vec![r1, n, r0]);
148
149 cores.push(g1);
150 cores.push(g2);
151 } else {
152 return Err(TensorError::UnsupportedDType {
153 dtype: format!("ndim={}", ndim),
154 operation: "Tensor Ring decomposition for ndim > 2".to_string(),
155 });
156 }
157
158 Ok(TensorRing::new(cores, ranks.to_vec(), shape.to_vec()))
159}
160
161pub fn tensor_ring_reconstruct(tr: &TensorRing) -> Result<DenseTensor, TensorError> {
171 let ndim = tr.ndim();
172
173 if ndim == 2 && tr.cores.len() >= 2 {
174 let g1 = &tr.cores[0];
177 let g2 = &tr.cores[1];
178
179 let g1_shape = g1.shape();
180 let g2_shape = g2.shape();
181
182 let m = g1_shape[1]; let n = g2_shape[1]; let r0 = g1_shape[0]; let r1 = g1_shape[2]; if r1 != g2_shape[0] {
189 return Err(TensorError::ShapeMismatch {
190 expected: vec![r1],
191 got: vec![g2_shape[0]],
192 });
193 }
194
195 if r0 != g2_shape[2] {
196 return Err(TensorError::ShapeMismatch {
197 expected: vec![r0],
198 got: vec![g2_shape[2]],
199 });
200 }
201
202 let g1_data = g1.data();
203 let g2_data = g2.data();
204 let mut result = vec![0.0; m * n];
205
206 for i in 0..m {
208 for j in 0..n {
209 let mut sum = 0.0;
210 for alpha in 0..r0 {
211 for beta in 0..r1 {
212 let g1_val = g1_data[alpha * m * r1 + i * r1 + beta];
214 let g2_val = g2_data[beta * n * r0 + j * r0 + alpha];
216 sum += g1_val * g2_val;
217 }
218 }
219 result[i * n + j] = sum;
220 }
221 }
222
223 Ok(DenseTensor::from_vec(result, vec![m, n]))
224 } else {
225 Err(TensorError::UnsupportedDType {
226 dtype: format!("ndim={}", ndim),
227 operation: "Tensor Ring reconstruction".to_string(),
228 })
229 }
230}
231
232pub fn compress_tensor_ring(
243 tensor: &DenseTensor,
244 target_rank: usize,
245) -> Result<TensorRing, TensorError> {
246 let shape = tensor.shape();
247
248 if shape.len() != 2 {
249 return Err(TensorError::DimensionMismatch {
250 expected: 2,
251 got: shape.len(),
252 });
253 }
254
255 let ranks = vec![target_rank, target_rank, target_rank];
257
258 tensor_ring_decompose(tensor, &ranks)
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between two equally long buffers.
    fn mse(expected: &[f64], actual: &[f64]) -> f64 {
        assert_eq!(expected.len(), actual.len(), "buffer lengths differ");
        expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / expected.len() as f64
    }

    #[test]
    fn test_tensor_ring_2d() {
        let tensor = DenseTensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![4, 2],
        );

        let ranks = vec![2, 2, 2];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();

        assert_eq!(tr.cores.len(), 2);
        assert_eq!(tr.ranks, ranks);
        assert_eq!(tr.original_shape, vec![4, 2]);
        // Cores are laid out as [r0, m, r1] and [r1, n, r0].
        assert_eq!(tr.cores[0].shape().to_vec(), vec![2, 4, 2]);
        assert_eq!(tr.cores[1].shape().to_vec(), vec![2, 2, 2]);
        assert!(tr.compression_ratio() > 0.0);
    }

    #[test]
    fn test_tensor_ring_reconstruct() {
        // Rank-1 matrix: rows are multiples of [1, 1].
        let tensor = DenseTensor::from_vec(
            vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0],
            vec![4, 2],
        );

        let ranks = vec![2, 2, 2];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();
        let reconstructed = tr.reconstruct().unwrap();

        assert_eq!(reconstructed.shape(), tensor.shape());

        let error = mse(&tensor.data(), &reconstructed.data());
        assert!(error < 1e-6, "MSE too high: {}", error);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = DenseTensor::from_vec(vec![1.0; 64 * 64], vec![64, 64]);

        let tr = compress_tensor_ring(&tensor, 8).unwrap();

        // Cores are [8, 64, 8] and [8, 64, 8]: 8192 stored values for
        // 4096 dense ones, so this "compression" actually doubles the size.
        let ratio = tr.compression_ratio();
        assert!((ratio - 0.5).abs() < 1e-12, "unexpected ratio: {}", ratio);
    }

    #[test]
    fn test_tensor_ring_rank1() {
        let tensor = DenseTensor::from_vec(vec![2.0, 4.0, 3.0, 6.0], vec![2, 2]);

        let ranks = vec![1, 1, 1];
        let tr = tensor_ring_decompose(&tensor, &ranks).unwrap();
        let reconstructed = tr.reconstruct().unwrap();

        let orig_data = tensor.data();
        let recon_data = reconstructed.data();

        for (a, b) in orig_data.iter().zip(recon_data.iter()) {
            assert!((a - b).abs() < 1e-4, "Mismatch: {} vs {}", a, b);
        }
    }

    #[test]
    fn test_decompose_rejects_wrong_rank_count() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert!(tensor_ring_decompose(&tensor, &[1, 1]).is_err());
    }

    #[test]
    fn test_decompose_rejects_open_ring() {
        let tensor = DenseTensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        // First and last rank differ, so the ring cannot close.
        assert!(tensor_ring_decompose(&tensor, &[1, 2, 3]).is_err());
    }

    #[test]
    fn test_compress_rejects_non_matrix() {
        let tensor = DenseTensor::from_vec(vec![0.0; 8], vec![2, 2, 2]);
        let err = compress_tensor_ring(&tensor, 2).unwrap_err();
        assert!(matches!(
            err,
            TensorError::DimensionMismatch { expected: 2, got: 3 }
        ));
    }

    #[test]
    fn test_reconstruct_rejects_non_matrix_ring() {
        let tr = TensorRing::new(Vec::new(), vec![1, 1], vec![3]);
        assert!(tr.reconstruct().is_err());
    }

    #[test]
    fn test_compression_ratio_with_empty_cores() {
        let tr = TensorRing::new(Vec::new(), Vec::new(), vec![2, 2]);
        assert_eq!(tr.compression_ratio(), f64::MAX);
    }
}