Skip to main content

god_gragh/tensor/
sparse.rs

1//! 稀疏张量实现
2//!
3//! 提供 COO(Coordinate)和 CSR(Compressed Sparse Row)格式
4//! 用于高效的图神经网络计算
5
6use core::fmt;
7
8use crate::tensor::dense::DenseTensor;
9use crate::tensor::error::TensorError;
10use crate::tensor::traits::{COOView, DType, Device, SparseTensorOps, TensorBase};
11
12/// COO(Coordinate)格式稀疏张量
13#[derive(Debug, Clone)]
14pub struct COOTensor {
15    row_indices: Vec<usize>,
16    col_indices: Vec<usize>,
17    values: DenseTensor,
18    shape: [usize; 2],
19}
20
#[cfg(feature = "tensor-sparse")]
impl COOTensor {
    /// Creates a COO tensor from parallel coordinate arrays and a value tensor.
    ///
    /// Entry `k` is stored at `(row_indices[k], col_indices[k])`. Duplicate
    /// coordinates are stored as given (not merged).
    ///
    /// # Panics
    ///
    /// Panics if the index arrays and `values` differ in length, or if any row
    /// or column index lies outside `shape`.
    pub fn new(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            row_indices.len(),
            col_indices.len(),
            "Row and column indices must have the same length"
        );
        assert_eq!(
            row_indices.len(),
            values.numel(),
            "Indices length must match values length"
        );
        // Reject out-of-range coordinates here: later flattening via
        // `row * cols + col` would otherwise silently alias another cell
        // (or panic far away from the real cause).
        assert!(
            row_indices.iter().all(|&r| r < shape[0]),
            "Row index out of bounds for shape {:?}",
            shape
        );
        assert!(
            col_indices.iter().all(|&c| c < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_indices,
            col_indices,
            values,
            shape,
        }
    }

    /// Returns the number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Builds a COO tensor from `(row, col, value)` triples, preserving their order.
    ///
    /// # Panics
    ///
    /// Panics if any edge endpoint lies outside `shape` (see [`COOTensor::new`]).
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        // Single pass over `edges` instead of one pass per output vector.
        let mut row_indices = Vec::with_capacity(edges.len());
        let mut col_indices = Vec::with_capacity(edges.len());
        let mut values_data = Vec::with_capacity(edges.len());
        for &(row, col, value) in edges {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(value);
        }
        let values = DenseTensor::new(values_data, vec![edges.len()]);
        Self::new(row_indices, col_indices, values, shape)
    }

    /// Returns the row coordinate of every stored entry.
    pub fn row_indices(&self) -> &[usize] {
        &self.row_indices
    }

    /// Returns the column coordinate of every stored entry.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Returns the stored values.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
77
78/// CSR(Compressed Sparse Row)格式稀疏张量
79#[derive(Debug, Clone)]
80pub struct CSRTensor {
81    row_offsets: Vec<usize>,
82    col_indices: Vec<usize>,
83    values: DenseTensor,
84    shape: [usize; 2],
85}
86
#[cfg(feature = "tensor-sparse")]
impl CSRTensor {
    /// Creates a CSR tensor from row offsets, column indices and values.
    ///
    /// # Panics
    ///
    /// Panics if `col_indices` and `values` differ in length, if `row_offsets`
    /// does not have `rows + 1` elements, if it does not start at 0, is not
    /// non-decreasing, or does not end at the number of stored entries, or if
    /// any column index lies outside `shape`.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            col_indices.len(),
            values.numel(),
            "Column indices length must match values length"
        );
        // `row`, `to_coo` and `to_dense` read `row_offsets[r + 1]` for every
        // row, so the offsets table must be complete and consistent.
        assert_eq!(
            row_offsets.len(),
            shape[0] + 1,
            "Row offsets length must equal rows + 1"
        );
        assert!(
            row_offsets[0] == 0
                && row_offsets.windows(2).all(|w| w[0] <= w[1])
                && row_offsets[shape[0]] == col_indices.len(),
            "Row offsets must start at 0, be non-decreasing and end at nnz"
        );
        assert!(
            col_indices.iter().all(|&c| c < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_offsets,
            col_indices,
            values,
            shape,
        }
    }

    /// Returns the number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Converts a COO tensor into CSR layout.
    ///
    /// Entries are bucketed by row with a counting sort; within a row they keep
    /// their original COO order.
    pub fn from_coo(coo: &COOTensor) -> Self {
        let rows = coo.shape[0];
        let nnz = coo.nnz();

        // Count entries per row into slot `row + 1` ...
        let mut row_offsets = vec![0; rows + 1];
        for &row in &coo.row_indices {
            row_offsets[row + 1] += 1;
        }

        // ... then prefix-sum so `row_offsets[r]` is the first slot of row `r`.
        let mut running = 0;
        for offset in row_offsets.iter_mut() {
            running += *offset;
            *offset = running;
        }

        // Scatter each entry into the next free slot of its row.
        let mut col_indices = vec![0; nnz];
        let mut values_data = vec![0.0; nnz];
        let mut next_slot = row_offsets.clone();
        for ((&row, &col), &value) in coo
            .row_indices
            .iter()
            .zip(&coo.col_indices)
            .zip(coo.values.data())
        {
            let pos = next_slot[row];
            col_indices[pos] = col;
            values_data[pos] = value;
            next_slot[row] += 1;
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Self::new(row_offsets, col_indices, values, coo.shape)
    }

    /// Returns the row offsets table (`rows + 1` entries).
    pub fn row_offsets(&self) -> &[usize] {
        &self.row_offsets
    }

    /// Returns the column coordinate of every stored entry, grouped by row.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Returns the stored values, parallel to `col_indices`.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
163
164/// 稀疏张量枚举:支持多种稀疏格式
165#[derive(Clone)]
166pub enum SparseTensor {
167    /// COO(Coordinate)格式
168    COO(COOTensor),
169    /// CSR(Compressed Sparse Row)格式
170    CSR(CSRTensor),
171}
172
#[cfg(feature = "tensor-sparse")]
impl SparseTensor {
    /// Creates a COO-backed sparse tensor (see `COOTensor::new` for panics).
    pub fn coo(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
    }

    /// Creates a CSR-backed sparse tensor (see `CSRTensor::new` for panics).
    pub fn csr(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
    }

    /// Returns the number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        match self {
            SparseTensor::COO(coo) => coo.nnz(),
            SparseTensor::CSR(csr) => csr.nnz(),
        }
    }

    /// Returns a CSR copy of this tensor (a clone if it is already CSR).
    pub fn to_csr(&self) -> CSRTensor {
        match self {
            SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
            SparseTensor::CSR(csr) => csr.clone(),
        }
    }

    /// Returns a COO copy of this tensor (a clone if it is already COO).
    ///
    /// CSR input yields entries in row-major storage order.
    pub fn to_coo(&self) -> COOTensor {
        match self {
            SparseTensor::COO(coo) => coo.clone(),
            SparseTensor::CSR(csr) => {
                // Expand the compressed row offsets back into one row index per entry.
                let data = csr.values.data();
                let mut row_indices = Vec::with_capacity(csr.nnz());
                let mut values_data = Vec::with_capacity(csr.nnz());
                for row in 0..csr.shape[0] {
                    let span = csr.row_offsets[row]..csr.row_offsets[row + 1];
                    row_indices.extend(span.clone().map(|_| row));
                    values_data.extend(span.map(|i| data[i]));
                }

                let values = DenseTensor::new(values_data, vec![csr.nnz()]);
                COOTensor::new(row_indices, csr.col_indices.clone(), values, csr.shape)
            }
        }
    }

    /// Returns a borrowed COO view of the stored entries.
    ///
    /// Known limitation: for CSR storage this returns an *empty* view with
    /// shape `[0, 0]`; call `to_coo` to obtain the entries instead.
    pub fn coo_view(&self) -> COOView<'_> {
        match self {
            SparseTensor::COO(coo) => COOView::new(
                &coo.row_indices,
                &coo.col_indices,
                coo.values.data(),
                coo.shape,
            ),
            SparseTensor::CSR(_) => {
                // CSR keeps no per-entry row indices to borrow, and a converted
                // COOTensor would be a temporary, so no real view can be returned.
                COOView::new(&[], &[], &[], [0, 0])
            }
        }
    }

    /// Builds a COO-backed sparse tensor from `(row, col, value)` triples.
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        SparseTensor::COO(COOTensor::from_edges(edges, shape))
    }

    /// Multiplies this sparse matrix by a dense vector (`y = A · x`).
    ///
    /// Duplicate coordinates contribute additively.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::DimensionMismatch` if the tensor is not 2-D and
    /// `TensorError::ShapeMismatch` if `x` is not a vector of length `cols`.
    pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }

        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];

        if x.shape() != [cols] {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols],
                got: x.shape().to_vec(),
            });
        }

        let x_data = x.data();
        let mut result = vec![0.0; rows];

        // Read the native storage directly rather than cloning into COO first.
        match self {
            SparseTensor::COO(coo) => {
                for ((&row, &col), &val) in coo
                    .row_indices
                    .iter()
                    .zip(&coo.col_indices)
                    .zip(coo.values.data())
                {
                    result[row] += val * x_data[col];
                }
            }
            SparseTensor::CSR(csr) => {
                let vals = csr.values.data();
                for (row, acc) in result.iter_mut().enumerate() {
                    for i in csr.row_offsets[row]..csr.row_offsets[row + 1] {
                        *acc += vals[i] * x_data[csr.col_indices[i]];
                    }
                }
            }
        }

        Ok(DenseTensor::new(result, vec![rows]))
    }

    /// Multiplies two sparse matrices (`C = A · B`), returning `C` in COO format.
    ///
    /// Output coordinates are sorted row-major; entries whose products cancel
    /// to zero are still stored.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::ShapeMismatch` when the column count of `self`
    /// differs from the row count of `other`.
    pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
        use std::collections::HashMap;

        let shape_a = self.shape();
        let shape_b = other.shape();
        let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
        let (rows_b, cols_b) = (shape_b[0], shape_b[1]);

        if cols_a != rows_b {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols_a],
                got: vec![rows_b],
            });
        }

        // Borrow each operand in the layout we need, converting only if required.
        let converted_a;
        let coo_a = match self {
            SparseTensor::COO(coo) => coo,
            SparseTensor::CSR(_) => {
                converted_a = self.to_coo();
                &converted_a
            }
        };
        let converted_b;
        let csr_b = match other {
            SparseTensor::CSR(csr) => csr,
            SparseTensor::COO(coo) => {
                converted_b = CSRTensor::from_coo(coo);
                &converted_b
            }
        };
        let b_values = csr_b.values.data();

        // With B indexed by row, each entry (i, k) of A only visits row k of B,
        // instead of scanning every stored entry of B (O(nnz_a * nnz_b)).
        let mut result_map: HashMap<(usize, usize), f64> = HashMap::new();
        for ((&row_a, &col_a), &val_a) in coo_a
            .row_indices
            .iter()
            .zip(&coo_a.col_indices)
            .zip(coo_a.values.data())
        {
            if col_a >= rows_b {
                // No row of B can match this column.
                continue;
            }
            for j in csr_b.row_offsets[col_a]..csr_b.row_offsets[col_a + 1] {
                *result_map
                    .entry((row_a, csr_b.col_indices[j]))
                    .or_insert(0.0) += val_a * b_values[j];
            }
        }

        let mut entries: Vec<((usize, usize), f64)> = result_map.into_iter().collect();
        // Keys are unique, so an unstable sort gives the same order as a stable one.
        entries.sort_unstable_by_key(|&(pos, _)| pos);

        let nnz = entries.len();
        let mut row_indices = Vec::with_capacity(nnz);
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values_data = Vec::with_capacity(nnz);
        for ((row, col), val) in entries {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(val);
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Ok(SparseTensor::COO(COOTensor::new(
            row_indices,
            col_indices,
            values,
            [rows_a, cols_b],
        )))
    }
}
362
#[cfg(feature = "tensor-sparse")]
impl SparseTensorOps for SparseTensor {
    /// Number of stored entries in whichever layout is active.
    fn nnz(&self) -> usize {
        match self {
            Self::COO(inner) => inner.nnz(),
            Self::CSR(inner) => inner.nnz(),
        }
    }

    /// Borrowed COO view; empty for CSR storage (see `SparseTensor::coo_view`).
    fn coo(&self) -> COOView<'_> {
        SparseTensor::coo_view(self)
    }

    /// Per-entry row indices.
    ///
    /// CSR storage keeps only row offsets, so the CSR variant yields an empty
    /// slice rather than materialized row indices.
    fn row_indices(&self) -> &[usize] {
        match self {
            Self::CSR(_) => &[],
            Self::COO(inner) => inner.row_indices(),
        }
    }

    /// Per-entry column indices (available for both layouts).
    fn col_indices(&self) -> &[usize] {
        match self {
            Self::CSR(inner) => inner.col_indices(),
            Self::COO(inner) => inner.col_indices(),
        }
    }

    /// Stored values (available for both layouts).
    fn values(&self) -> &DenseTensor {
        match self {
            Self::CSR(inner) => inner.values(),
            Self::COO(inner) => inner.values(),
        }
    }
}
397
#[cfg(feature = "tensor-sparse")]
impl TensorBase for SparseTensor {
    /// Logical dense shape, always two elements `[rows, cols]`.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    fn dtype(&self) -> DType {
        DType::F64
    }

    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Materializes the matrix as a dense row-major `[rows, cols]` tensor.
    ///
    /// Duplicate coordinates are summed, matching how `spmv`/`spmm` treat them.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let mut data = vec![0.0; rows * cols];

        // Walk the native storage directly instead of cloning into COO; use `+=`
        // so repeated coordinates accumulate rather than overwrite each other.
        match self {
            SparseTensor::COO(coo) => {
                for ((&row, &col), &val) in coo
                    .row_indices
                    .iter()
                    .zip(&coo.col_indices)
                    .zip(coo.values.data())
                {
                    data[row * cols + col] += val;
                }
            }
            SparseTensor::CSR(csr) => {
                let vals = csr.values.data();
                for row in 0..rows {
                    for i in csr.row_offsets[row]..csr.row_offsets[row + 1] {
                        data[row * cols + csr.col_indices[i]] += vals[i];
                    }
                }
            }
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    #[cfg(feature = "tensor-sparse")]
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}
440
#[cfg(feature = "tensor-sparse")]
impl fmt::Debug for SparseTensor {
    /// Prints a summary (shape, stored-entry count, sparsity) instead of the entries.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dims = self.shape();
        let summary_shape = [dims[0], dims[1]];
        let mut out = f.debug_struct("SparseTensor");
        out.field("shape", &summary_shape);
        out.field("nnz", &self.nnz());
        out.field("sparsity", &self.sparsity());
        out.finish()
    }
}
454
455/// COO 张量实现
456impl COOTensor {
457    /// 获取形状
458    pub fn shape_array(&self) -> [usize; 2] {
459        self.shape
460    }
461}
462
463/// CSR 张量实现
464impl CSRTensor {
465    /// 获取形状
466    pub fn shape_array(&self) -> [usize; 2] {
467        self.shape
468    }
469
470    /// 获取指定行的非零元素
471    pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
472        if row >= self.shape[0] {
473            return None;
474        }
475
476        let start = self.row_offsets[row];
477        let end = self.row_offsets[row + 1];
478
479        if start == end {
480            return Some(Vec::new());
481        }
482
483        let mut result = Vec::with_capacity(end - start);
484        for i in start..end {
485            result.push((self.col_indices[i], self.values.data()[i]));
486        }
487        Some(result)
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 fixture: [[0, 1, 2], [0, 0, 3], [4, 0, 0]].
    fn sample_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)]
    }

    #[test]
    fn test_coo_creation() {
        let edges = sample_edges();
        let coo = SparseTensor::from_edges(&edges, [3, 3]);

        assert_eq!(coo.nnz(), 4);
        assert_eq!(coo.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let edges = sample_edges();
        let coo = SparseTensor::from_edges(&edges, [3, 3]);
        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.row_offsets(), &[0, 2, 3, 4]);
        assert_eq!(csr.col_indices(), &[1, 2, 2, 0]);
    }

    #[test]
    fn test_csr_row_access() {
        let edges = sample_edges();
        let csr = SparseTensor::from_edges(&edges, [3, 3]).to_csr();

        assert_eq!(csr.row(0), Some(vec![(1, 1.0), (2, 2.0)]));
        assert_eq!(csr.row(1), Some(vec![(2, 3.0)]));
        assert_eq!(csr.row(3), None);
    }

    #[test]
    fn test_csr_to_coo_roundtrip() {
        let edges = sample_edges();
        let csr = SparseTensor::CSR(SparseTensor::from_edges(&edges, [3, 3]).to_csr());
        let coo = csr.to_coo();

        assert_eq!(coo.row_indices(), &[0, 0, 1, 2]);
        assert_eq!(coo.col_indices(), &[1, 2, 2, 0]);
        assert_eq!(coo.values().data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let edges = sample_edges();
        let sparse = SparseTensor::from_edges(&edges, [3, 3]);
        let dense = sparse.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(dense.get(&[2, 0]).unwrap(), 4.0);
    }

    #[test]
    fn test_spmv() {
        let edges = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let sparse = SparseTensor::from_edges(&edges, [2, 2]);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        let result = sparse.spmv(&x).unwrap();
        // [1,2; 3,4] * [1; 2] = [1*1+2*2; 3*1+4*2] = [5; 11]
        assert_eq!(result.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmv_csr_matches_coo() {
        let edges = sample_edges();
        let coo = SparseTensor::from_edges(&edges, [3, 3]);
        let csr = SparseTensor::CSR(coo.to_csr());
        let x = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);

        // [0,1,2; 0,0,3; 4,0,0] * [1; 2; 3] = [8; 9; 4]
        assert_eq!(coo.spmv(&x).unwrap().data(), &[8.0, 9.0, 4.0]);
        assert_eq!(csr.spmv(&x).unwrap().data(), &[8.0, 9.0, 4.0]);
    }

    #[test]
    fn test_spmv_shape_mismatch() {
        let edges = sample_edges();
        let sparse = SparseTensor::from_edges(&edges, [3, 3]);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        assert!(sparse.spmv(&x).is_err());
    }

    #[test]
    fn test_spmm() {
        let edges_a = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let a = SparseTensor::from_edges(&edges_a, [2, 2]);

        let edges_b = vec![(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];
        let b = SparseTensor::from_edges(&edges_b, [2, 2]);

        let result = a.spmm(&b).unwrap();
        let result_dense = result.to_dense();

        // [1,2; 3,4] * [5,6; 7,8] = [19,22; 43,50]
        assert_eq!(result_dense.get(&[0, 0]).unwrap(), 19.0);
        assert_eq!(result_dense.get(&[0, 1]).unwrap(), 22.0);
        assert_eq!(result_dense.get(&[1, 0]).unwrap(), 43.0);
        assert_eq!(result_dense.get(&[1, 1]).unwrap(), 50.0);
    }

    #[test]
    fn test_spmm_with_csr_operands() {
        let edges_a = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let a = SparseTensor::CSR(SparseTensor::from_edges(&edges_a, [2, 2]).to_csr());

        let edges_b = vec![(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];
        let b = SparseTensor::CSR(SparseTensor::from_edges(&edges_b, [2, 2]).to_csr());

        let result_dense = a.spmm(&b).unwrap().to_dense();

        assert_eq!(result_dense.get(&[0, 0]).unwrap(), 19.0);
        assert_eq!(result_dense.get(&[0, 1]).unwrap(), 22.0);
        assert_eq!(result_dense.get(&[1, 0]).unwrap(), 43.0);
        assert_eq!(result_dense.get(&[1, 1]).unwrap(), 50.0);
    }

    #[test]
    fn test_spmm_shape_mismatch() {
        let a = SparseTensor::from_edges(&[(0, 0, 1.0)], [2, 2]);
        let b = SparseTensor::from_edges(&[(0, 0, 1.0)], [3, 3]);

        assert!(a.spmm(&b).is_err());
    }
}