// god_gragh/tensor/sparse.rs
//! Sparse tensor implementation.
//!
//! Provides COO (Coordinate) and CSR (Compressed Sparse Row) formats
//! for efficient graph-neural-network computation.

use core::fmt;

use crate::tensor::traits::{TensorBase, DType, Device, COOView, SparseTensorOps};
use crate::tensor::dense::DenseTensor;
use crate::tensor::error::TensorError;

/// COO (Coordinate) format sparse tensor.
///
/// Stores one `(row, col, value)` triple per non-zero entry. The three
/// parallel containers below all have the same length (the nnz count).
#[derive(Debug, Clone)]
pub struct COOTensor {
    // Row index of each stored entry.
    row_indices: Vec<usize>,
    // Column index of each stored entry.
    col_indices: Vec<usize>,
    // Stored values, kept as a 1-D dense tensor of length nnz.
    values: DenseTensor,
    // Matrix shape as [rows, cols].
    shape: [usize; 2],
}

#[cfg(feature = "tensor-sparse")]
impl COOTensor {
    /// Creates a new COO tensor from parallel index/value arrays.
    ///
    /// # Panics
    ///
    /// Panics if the index vectors and `values` do not all have the same
    /// length, or if any index lies outside `shape`.
    pub fn new(row_indices: Vec<usize>, col_indices: Vec<usize>, values: DenseTensor, shape: [usize; 2]) -> Self {
        assert_eq!(
            row_indices.len(),
            col_indices.len(),
            "Row and column indices must have the same length"
        );
        assert_eq!(
            row_indices.len(),
            values.numel(),
            "Indices length must match values length"
        );
        // Reject out-of-bounds coordinates up front: a bad index would
        // otherwise only surface later as a panic (or silent corruption)
        // inside from_coo / to_dense / spmv.
        assert!(
            row_indices.iter().all(|&r| r < shape[0]),
            "Row index out of bounds for shape {:?}",
            shape
        );
        assert!(
            col_indices.iter().all(|&c| c < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_indices,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Builds a COO tensor from an edge list of `(row, col, weight)` triples.
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        // Single pass over the edge list instead of three separate scans.
        let mut row_indices = Vec::with_capacity(edges.len());
        let mut col_indices = Vec::with_capacity(edges.len());
        let mut values_data = Vec::with_capacity(edges.len());
        for &(r, c, v) in edges {
            row_indices.push(r);
            col_indices.push(c);
            values_data.push(v);
        }
        let values = DenseTensor::new(values_data, vec![edges.len()]);
        Self::new(row_indices, col_indices, values, shape)
    }

    /// Row index of each stored entry.
    pub fn row_indices(&self) -> &[usize] {
        &self.row_indices
    }

    /// Column index of each stored entry.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values as a 1-D dense tensor.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}

/// CSR (Compressed Sparse Row) format sparse tensor.
///
/// Row `r`'s entries occupy `col_indices[row_offsets[r]..row_offsets[r + 1]]`
/// and the matching slice of `values`.
#[derive(Debug, Clone)]
pub struct CSRTensor {
    // Per-row start offsets into col_indices/values; length is rows + 1.
    row_offsets: Vec<usize>,
    // Column index of each stored entry, grouped by row.
    col_indices: Vec<usize>,
    // Stored values, kept as a 1-D dense tensor of length nnz.
    values: DenseTensor,
    // Matrix shape as [rows, cols].
    shape: [usize; 2],
}

#[cfg(feature = "tensor-sparse")]
impl CSRTensor {
    /// Creates a new CSR tensor.
    ///
    /// # Panics
    ///
    /// Panics if the CSR invariants do not hold: `row_offsets` must contain
    /// `shape[0] + 1` monotonically non-decreasing entries starting at 0 and
    /// ending at `nnz`; `col_indices` must match `values` in length and stay
    /// within `shape[1]`. `from_coo`, `to_coo` and `row` all index on these
    /// invariants, so enforcing them here turns later corruption into an
    /// immediate, diagnosable panic.
    pub fn new(row_offsets: Vec<usize>, col_indices: Vec<usize>, values: DenseTensor, shape: [usize; 2]) -> Self {
        assert_eq!(
            col_indices.len(),
            values.numel(),
            "Column indices length must match values length"
        );
        assert_eq!(
            row_offsets.len(),
            shape[0] + 1,
            "Row offsets must have shape[0] + 1 entries"
        );
        assert_eq!(row_offsets[0], 0, "Row offsets must start at 0");
        assert_eq!(
            *row_offsets.last().unwrap(),
            values.numel(),
            "Row offsets must end at nnz"
        );
        assert!(
            row_offsets.windows(2).all(|w| w[0] <= w[1]),
            "Row offsets must be non-decreasing"
        );
        assert!(
            col_indices.iter().all(|&c| c < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_offsets,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Converts a COO tensor into CSR form (classic counting-sort scheme).
    pub fn from_coo(coo: &COOTensor) -> Self {
        let mut row_offsets = vec![0; coo.shape[0] + 1];
        let mut col_indices = vec![0; coo.nnz()];
        let mut values_data = vec![0.0; coo.nnz()];

        // Count entries per row (shifted by one so the prefix sum below
        // yields start offsets directly).
        for &row in &coo.row_indices {
            row_offsets[row + 1] += 1;
        }

        // Prefix-sum the counts into offsets.
        for i in 1..row_offsets.len() {
            row_offsets[i] += row_offsets[i - 1];
        }

        // Scatter each COO entry into its row's next free slot. Within a
        // row, entries keep their original COO order.
        let mut row_pos = row_offsets.clone();
        for (i, (&row, &col)) in coo.row_indices.iter().zip(coo.col_indices.iter()).enumerate() {
            let pos = row_pos[row];
            col_indices[pos] = col;
            values_data[pos] = coo.values.data()[i];
            row_pos[row] += 1;
        }

        let values = DenseTensor::new(values_data, vec![coo.nnz()]);
        Self::new(row_offsets, col_indices, values, coo.shape)
    }

    /// Per-row start offsets (length = rows + 1).
    pub fn row_offsets(&self) -> &[usize] {
        &self.row_offsets
    }

    /// Column index of each stored entry, grouped by row.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values as a 1-D dense tensor.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}

/// Sparse tensor enum: supports multiple sparse storage formats.
#[derive(Clone)]
pub enum SparseTensor {
    /// COO (Coordinate) format.
    COO(COOTensor),
    /// CSR (Compressed Sparse Row) format.
    CSR(CSRTensor),
}

#[cfg(feature = "tensor-sparse")]
impl SparseTensor {
    /// Creates a COO-format sparse tensor.
    pub fn coo(row_indices: Vec<usize>, col_indices: Vec<usize>, values: DenseTensor, shape: [usize; 2]) -> Self {
        SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
    }

    /// Creates a CSR-format sparse tensor.
    pub fn csr(row_offsets: Vec<usize>, col_indices: Vec<usize>, values: DenseTensor, shape: [usize; 2]) -> Self {
        SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        match self {
            SparseTensor::COO(coo) => coo.nnz(),
            SparseTensor::CSR(csr) => csr.nnz(),
        }
    }

    /// Converts to CSR format (clones when already CSR).
    pub fn to_csr(&self) -> CSRTensor {
        match self {
            SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
            SparseTensor::CSR(csr) => csr.clone(),
        }
    }

    /// Converts to COO format (clones when already COO).
    pub fn to_coo(&self) -> COOTensor {
        match self {
            SparseTensor::COO(coo) => coo.clone(),
            SparseTensor::CSR(csr) => {
                // Expand the compressed row offsets back into one explicit
                // row index per stored value, in a single pass per row.
                let mut row_indices = Vec::with_capacity(csr.nnz());
                let col_indices = csr.col_indices.clone();
                let mut values_data = Vec::with_capacity(csr.nnz());

                for row in 0..csr.shape[0] {
                    let start = csr.row_offsets[row];
                    let end = csr.row_offsets[row + 1];
                    for i in start..end {
                        row_indices.push(row);
                        values_data.push(csr.values.data()[i]);
                    }
                }

                let values = DenseTensor::new(values_data, vec![csr.nnz()]);
                COOTensor::new(row_indices, col_indices, values, csr.shape)
            }
        }
    }

    /// Returns a borrowed COO view of the tensor.
    ///
    /// NOTE(review): for CSR tensors this returns an EMPTY view with shape
    /// `[0, 0]`, because a borrowed view cannot reference row indices that
    /// were never materialized. Callers needing COO data from a CSR tensor
    /// must use [`SparseTensor::to_coo`] instead; consider changing the
    /// `coo_view` API (e.g. to return `Option<COOView>`) to make this
    /// limitation explicit.
    pub fn coo_view(&self) -> COOView<'_> {
        match self {
            SparseTensor::COO(coo) => {
                COOView::new(&coo.row_indices, &coo.col_indices, coo.values.data(), coo.shape)
            }
            SparseTensor::CSR(_) => COOView::new(&[], &[], &[], [0, 0]),
        }
    }

    /// Builds a sparse tensor (COO format) from `(row, col, weight)` edges.
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        SparseTensor::COO(COOTensor::from_edges(edges, shape))
    }

    /// Sparse matrix – dense vector product: `y = A * x`.
    ///
    /// Duplicate `(row, col)` entries accumulate into the result.
    ///
    /// # Errors
    ///
    /// Returns `DimensionMismatch` if `self` is not 2-D, and `ShapeMismatch`
    /// if `x` is not a vector of length `cols`.
    pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }

        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];

        if x.shape() != [cols] {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols],
                got: x.shape().to_vec(),
            });
        }

        let mut result = vec![0.0; rows];
        let coo = self.to_coo();

        // y[row] += A[row, col] * x[col]
        for (i, (&row, &col)) in coo.row_indices.iter().zip(coo.col_indices.iter()).enumerate() {
            result[row] += coo.values.data()[i] * x.data()[col];
        }

        Ok(DenseTensor::new(result, vec![rows]))
    }

    /// Sparse matrix – sparse matrix product, returned in COO format with
    /// entries sorted by `(row, col)`.
    ///
    /// # Errors
    ///
    /// Returns `ShapeMismatch` when the inner dimensions disagree.
    pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
        let shape_a = self.shape();
        let shape_b = other.shape();
        let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
        let (rows_b, cols_b) = (shape_b[0], shape_b[1]);

        if cols_a != rows_b {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols_a],
                got: vec![rows_b],
            });
        }

        // Multiply via COO entry lists.
        let coo_a = self.to_coo();
        let coo_b = other.to_coo();

        use std::collections::HashMap;

        // Group B's entries by row so each A entry only visits the B entries
        // it can actually combine with. The previous all-pairs scan was
        // O(nnz_a * nnz_b); this is O(nnz_b + nnz_a * avg_row_nnz_b).
        let mut b_rows: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
        for (j, (&row_b, &col_b)) in coo_b.row_indices.iter().zip(coo_b.col_indices.iter()).enumerate() {
            b_rows.entry(row_b).or_default().push((col_b, coo_b.values.data()[j]));
        }

        // Accumulate C[row_a, col_b] += A[row_a, col_a] * B[col_a, col_b].
        let mut result_map: HashMap<(usize, usize), f64> = HashMap::new();
        for (i, (&row_a, &col_a)) in coo_a.row_indices.iter().zip(coo_a.col_indices.iter()).enumerate() {
            let val_a = coo_a.values.data()[i];
            if let Some(b_row) = b_rows.get(&col_a) {
                for &(col_b, val_b) in b_row {
                    *result_map.entry((row_a, col_b)).or_insert(0.0) += val_a * val_b;
                }
            }
        }

        // Emit entries in (row, col) order for a deterministic layout.
        let mut entries: Vec<_> = result_map.into_iter().collect();
        entries.sort_unstable_by_key(|&(pos, _)| pos);

        let mut row_indices = Vec::with_capacity(entries.len());
        let mut col_indices = Vec::with_capacity(entries.len());
        let mut values_data = Vec::with_capacity(entries.len());
        for ((row, col), val) in entries {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(val);
        }

        let nnz = values_data.len();
        let values = DenseTensor::new(values_data, vec![nnz]);
        Ok(SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, [rows_a, cols_b])))
    }
}

#[cfg(feature = "tensor-sparse")]
impl SparseTensorOps for SparseTensor {
    fn nnz(&self) -> usize {
        // The inherent method already dispatches on the storage format.
        SparseTensor::nnz(self)
    }

    fn coo(&self) -> COOView<'_> {
        self.coo_view()
    }

    fn row_indices(&self) -> &[usize] {
        // NOTE(review): CSR has no materialized row indices to borrow, so
        // the CSR arm yields an empty slice — callers must not rely on this
        // method for CSR tensors (use `to_coo()` instead).
        if let SparseTensor::COO(coo) = self {
            coo.row_indices()
        } else {
            &[]
        }
    }

    fn col_indices(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => coo.col_indices(),
            SparseTensor::CSR(csr) => csr.col_indices(),
        }
    }

    fn values(&self) -> &DenseTensor {
        match self {
            SparseTensor::COO(coo) => coo.values(),
            SparseTensor::CSR(csr) => csr.values(),
        }
    }
}

#[cfg(feature = "tensor-sparse")]
impl TensorBase for SparseTensor {
    /// Logical shape as `[rows, cols]`.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    fn dtype(&self) -> DType {
        // Values are always stored as f64 (see DenseTensor usage above).
        DType::F64
    }

    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Materializes the sparse matrix as a dense row-major tensor.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let mut data = vec![0.0; rows * cols];
        let coo = self.to_coo();

        for (i, (&row, &col)) in coo.row_indices.iter().zip(coo.col_indices.iter()).enumerate() {
            // Accumulate rather than overwrite: duplicate (row, col) entries
            // must sum, matching the semantics `spmv` already implements
            // (the previous `=` silently kept only the last duplicate).
            data[row * cols + col] += coo.values.data()[i];
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    #[cfg(feature = "tensor-sparse")]
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}

#[cfg(feature = "tensor-sparse")]
impl fmt::Debug for SparseTensor {
    /// Compact summary: shape, non-zero count, and sparsity ratio.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dims = self.shape();
        f.debug_struct("SparseTensor")
            .field("shape", &[dims[0], dims[1]])
            .field("nnz", &self.nnz())
            .field("sparsity", &self.sparsity())
            .finish()
    }
}

/// COO accessors available without the `tensor-sparse` feature gate.
impl COOTensor {
    /// Returns the matrix shape as a fixed `[rows, cols]` array (by value,
    /// unlike slice-returning shape accessors).
    pub fn shape_array(&self) -> [usize; 2] {
        self.shape
    }
}

410/// CSR 张量实现
411impl CSRTensor {
412    /// 获取形状
413    pub fn shape_array(&self) -> [usize; 2] {
414        self.shape
415    }
416
417    /// 获取指定行的非零元素
418    pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
419        if row >= self.shape[0] {
420            return None;
421        }
422
423        let start = self.row_offsets[row];
424        let end = self.row_offsets[row + 1];
425
426        if start == end {
427            return Some(Vec::new());
428        }
429
430        let mut result = Vec::with_capacity(end - start);
431        for i in start..end {
432            result.push((self.col_indices[i], self.values.data()[i]));
433        }
434        Some(result)
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 3x3 fixture: a small directed graph with four weighted edges.
    fn sample_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)]
    }

    #[test]
    fn test_coo_creation() {
        let tensor = SparseTensor::from_edges(&sample_edges(), [3, 3]);

        assert_eq!(tensor.nnz(), 4);
        assert_eq!(tensor.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let tensor = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let csr = tensor.to_csr();

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.row_offsets(), &[0, 2, 3, 4]);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let sparse = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let dense = sparse.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(dense.get(&[2, 0]).unwrap(), 4.0);
    }

    #[test]
    fn test_spmv() {
        // [1,2; 3,4] * [1; 2] = [1*1+2*2; 3*1+4*2] = [5; 11]
        let matrix = SparseTensor::from_edges(
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)],
            [2, 2],
        );
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        let y = matrix.spmv(&x).unwrap();
        assert_eq!(y.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmm() {
        // [1,2; 3,4] * [5,6; 7,8] = [19,22; 43,50]
        let a = SparseTensor::from_edges(
            &[(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)],
            [2, 2],
        );
        let b = SparseTensor::from_edges(
            &[(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)],
            [2, 2],
        );

        let product = a.spmm(&b).unwrap().to_dense();
        assert_eq!(product.get(&[0, 0]).unwrap(), 19.0);
        assert_eq!(product.get(&[0, 1]).unwrap(), 22.0);
        assert_eq!(product.get(&[1, 0]).unwrap(), 43.0);
        assert_eq!(product.get(&[1, 1]).unwrap(), 50.0);
    }
}