Skip to main content

god_graph/tensor/
sparse.rs

1//! 稀疏张量实现
2//!
3//! 提供 COO(Coordinate)和 CSR(Compressed Sparse Row)格式
4//! 用于高效的图神经网络计算
5
6use core::fmt;
7
8use crate::tensor::dense::DenseTensor;
9use crate::tensor::error::TensorError;
10use crate::tensor::traits::{COOView, DType, Device, SparseTensorOps, TensorBase};
11
12/// COO(Coordinate)格式稀疏张量
13#[derive(Debug, Clone)]
14pub struct COOTensor {
15    row_indices: Vec<usize>,
16    col_indices: Vec<usize>,
17    values: DenseTensor,
18    shape: [usize; 2],
19}
20
21impl COOTensor {
22    /// 创建新的 COO 张量
23    pub fn new(
24        row_indices: Vec<usize>,
25        col_indices: Vec<usize>,
26        values: DenseTensor,
27        shape: [usize; 2],
28    ) -> Self {
29        assert_eq!(
30            row_indices.len(),
31            col_indices.len(),
32            "Row and column indices must have the same length"
33        );
34        assert_eq!(
35            row_indices.len(),
36            values.numel(),
37            "Indices length must match values length"
38        );
39        Self {
40            row_indices,
41            col_indices,
42            values,
43            shape,
44        }
45    }
46
47    /// 获取非零元素数量
48    pub fn nnz(&self) -> usize {
49        self.values.numel()
50    }
51
52    /// 从边列表创建 COO 张量
53    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
54        let row_indices: Vec<usize> = edges.iter().map(|&(r, _, _)| r).collect();
55        let col_indices: Vec<usize> = edges.iter().map(|&(_, c, _)| c).collect();
56        let values_data: Vec<f64> = edges.iter().map(|&(_, _, v)| v).collect();
57        let values = DenseTensor::new(values_data, vec![edges.len()]);
58        Self::new(row_indices, col_indices, values, shape)
59    }
60
61    /// 获取行索引
62    pub fn row_indices(&self) -> &[usize] {
63        &self.row_indices
64    }
65
66    /// 获取列索引
67    pub fn col_indices(&self) -> &[usize] {
68        &self.col_indices
69    }
70
71    /// 获取值
72    pub fn values(&self) -> &DenseTensor {
73        &self.values
74    }
75}
76
77/// CSR(Compressed Sparse Row)格式稀疏张量
78#[derive(Debug, Clone)]
79pub struct CSRTensor {
80    row_offsets: Vec<usize>,
81    col_indices: Vec<usize>,
82    values: DenseTensor,
83    shape: [usize; 2],
84}
85
#[cfg(feature = "tensor")]
impl CSRTensor {
    /// Builds a CSR tensor from its raw components.
    ///
    /// # Panics
    ///
    /// Panics when `col_indices.len()` differs from `values.numel()`.
    /// `row_offsets` is stored as given and is not validated here.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            col_indices.len(),
            values.numel(),
            "Column indices length must match values length"
        );

        CSRTensor {
            row_offsets,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored entries (the element count of the value tensor).
    pub fn nnz(&self) -> usize {
        let stored = &self.values;
        stored.numel()
    }

    /// Converts a COO tensor to CSR with a counting sort over the row index.
    ///
    /// Entries that share a row keep the relative order they had in `coo`.
    ///
    /// # Panics
    ///
    /// Panics if a row index in `coo` is not smaller than `coo.shape[0]`.
    pub fn from_coo(coo: &COOTensor) -> Self {
        let nnz = coo.nnz();
        let row_count = coo.shape[0];

        // Histogram pass: slot `r + 1` counts the entries that belong to row `r`.
        let mut row_offsets = vec![0usize; row_count + 1];
        for &r in &coo.row_indices {
            row_offsets[r + 1] += 1;
        }

        // Inclusive prefix sum turns the per-row counts into start offsets.
        let mut running = 0;
        for slot in row_offsets.iter_mut() {
            running += *slot;
            *slot = running;
        }

        // Scatter pass: `cursor[r]` is the next free position inside row `r`.
        let mut cursor = row_offsets.clone();
        let mut col_indices = vec![0usize; nnz];
        let mut values_data = vec![0.0; nnz];
        let source = coo.values.data();
        for (k, (&r, &c)) in coo
            .row_indices
            .iter()
            .zip(coo.col_indices.iter())
            .enumerate()
        {
            let dest = cursor[r];
            col_indices[dest] = c;
            values_data[dest] = source[k];
            cursor[r] = dest + 1;
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Self::new(row_offsets, col_indices, values, coo.shape)
    }

    /// Row start offsets into the column/value arrays.
    pub fn row_offsets(&self) -> &[usize] {
        self.row_offsets.as_slice()
    }

    /// Column coordinates of the stored entries.
    pub fn col_indices(&self) -> &[usize] {
        self.col_indices.as_slice()
    }

    /// Values of the stored entries.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
162
163/// 稀疏张量枚举:支持多种稀疏格式
164#[derive(Clone)]
165pub enum SparseTensor {
166    /// COO(Coordinate)格式
167    COO(COOTensor),
168    /// CSR(Compressed Sparse Row)格式
169    CSR(CSRTensor),
170}
171
172#[cfg(feature = "tensor")]
173impl SparseTensor {
174    /// 创建 COO 格式稀疏张量
175    pub fn coo(
176        row_indices: Vec<usize>,
177        col_indices: Vec<usize>,
178        values: DenseTensor,
179        shape: [usize; 2],
180    ) -> Self {
181        SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
182    }
183
184    /// 创建 CSR 格式稀疏张量
185    pub fn csr(
186        row_offsets: Vec<usize>,
187        col_indices: Vec<usize>,
188        values: DenseTensor,
189        shape: [usize; 2],
190    ) -> Self {
191        SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
192    }
193
194    /// 获取非零元素数量
195    pub fn nnz(&self) -> usize {
196        match self {
197            SparseTensor::COO(coo) => coo.nnz(),
198            SparseTensor::CSR(csr) => csr.nnz(),
199        }
200    }
201
202    /// 转换为 CSR 格式
203    pub fn to_csr(&self) -> CSRTensor {
204        match self {
205            SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
206            SparseTensor::CSR(csr) => csr.clone(),
207        }
208    }
209
210    /// 转换为 COO 格式
211    pub fn to_coo(&self) -> COOTensor {
212        match self {
213            SparseTensor::COO(coo) => coo.clone(),
214            SparseTensor::CSR(csr) => {
215                // CSR 转 COO
216                let mut row_indices = Vec::with_capacity(csr.nnz());
217                let col_indices = csr.col_indices.clone();
218                let mut values_data = Vec::with_capacity(csr.nnz());
219
220                for row in 0..csr.shape[0] {
221                    let start = csr.row_offsets[row];
222                    let end = csr.row_offsets[row + 1];
223                    for _ in start..end {
224                        row_indices.push(row);
225                    }
226                    for i in start..end {
227                        values_data.push(csr.values.data()[i]);
228                    }
229                }
230
231                let values = DenseTensor::new(values_data, vec![csr.nnz()]);
232                COOTensor::new(row_indices, col_indices, values, csr.shape)
233            }
234        }
235    }
236
237    /// 获取 COO 视图
238    pub fn coo_view(&self) -> COOView<'_> {
239        match self {
240            SparseTensor::COO(coo) => COOView::new(
241                &coo.row_indices,
242                &coo.col_indices,
243                coo.values.data(),
244                coo.shape,
245            ),
246            SparseTensor::CSR(_) => {
247                // For CSR, we need to convert to COO first, but we can't return a view
248                // So we return an empty view as a workaround (this is a limitation)
249                COOView::new(&[], &[], &[], [0, 0])
250            }
251        }
252    }
253
254    /// 从边列表创建稀疏张量(COO 格式)
255    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
256        SparseTensor::COO(COOTensor::from_edges(edges, shape))
257    }
258
259    /// 稀疏矩阵 - 稠密向量乘法
260    pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
261        if self.ndim() != 2 {
262            return Err(TensorError::DimensionMismatch {
263                expected: 2,
264                got: self.ndim(),
265            });
266        }
267
268        let shape = self.shape();
269        let rows = shape[0];
270        let cols = shape[1];
271
272        if x.shape() != [cols] {
273            return Err(TensorError::ShapeMismatch {
274                expected: vec![cols],
275                got: x.shape().to_vec(),
276            });
277        }
278
279        let mut result = vec![0.0; rows];
280        let coo = self.to_coo();
281
282        for (i, (&row, &col)) in coo
283            .row_indices
284            .iter()
285            .zip(coo.col_indices.iter())
286            .enumerate()
287        {
288            let val = coo.values.data()[i];
289            let x_val = x.data()[col];
290            result[row] += val * x_val;
291        }
292
293        Ok(DenseTensor::new(result, vec![rows]))
294    }
295
296    /// 稀疏矩阵 - 稀疏矩阵乘法
297    pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
298        let shape_a = self.shape();
299        let shape_b = other.shape();
300        let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
301        let (rows_b, cols_b) = (shape_b[0], shape_b[1]);
302
303        if cols_a != rows_b {
304            return Err(TensorError::ShapeMismatch {
305                expected: vec![cols_a],
306                got: vec![rows_b],
307            });
308        }
309
310        // 转换为 COO 进行乘法
311        let coo_a = self.to_coo();
312        let coo_b = other.to_coo();
313
314        // 使用哈希表累加结果
315        use std::collections::HashMap;
316        let mut result_map: HashMap<(usize, usize), f64> = HashMap::new();
317
318        for (i, (&row_a, &col_a)) in coo_a
319            .row_indices
320            .iter()
321            .zip(coo_a.col_indices.iter())
322            .enumerate()
323        {
324            let val_a = coo_a.values.data()[i];
325            for (j, (&row_b, &col_b)) in coo_b
326                .row_indices
327                .iter()
328                .zip(coo_b.col_indices.iter())
329                .enumerate()
330            {
331                if col_a == row_b {
332                    let val_b = coo_b.values.data()[j];
333                    *result_map.entry((row_a, col_b)).or_insert(0.0) += val_a * val_b;
334                }
335            }
336        }
337
338        // 转换回 COO 格式
339        let mut row_indices = Vec::new();
340        let mut col_indices = Vec::new();
341        let mut values_data = Vec::new();
342
343        let mut entries: Vec<_> = result_map.into_iter().collect();
344        entries.sort_by_key(|&(pos, _)| pos);
345
346        for ((row, col), val) in entries {
347            row_indices.push(row);
348            col_indices.push(col);
349            values_data.push(val);
350        }
351
352        let values = DenseTensor::new(values_data.clone(), vec![values_data.len()]);
353        Ok(SparseTensor::COO(COOTensor::new(
354            row_indices,
355            col_indices,
356            values,
357            [rows_a, cols_b],
358        )))
359    }
360}
361
#[cfg(feature = "tensor")]
impl SparseTensorOps for SparseTensor {
    /// Number of stored entries.
    fn nnz(&self) -> usize {
        let stored = match self {
            Self::COO(inner) => &inner.values,
            Self::CSR(inner) => &inner.values,
        };
        stored.numel()
    }

    /// COO view of the tensor; empty for the CSR variant (see `coo_view`).
    fn coo(&self) -> COOView<'_> {
        SparseTensor::coo_view(self)
    }

    /// Row coordinates of the stored entries.
    ///
    /// The CSR variant keeps no explicit row coordinates, so it yields an
    /// empty slice.
    fn row_indices(&self) -> &[usize] {
        if let Self::COO(inner) = self {
            inner.row_indices.as_slice()
        } else {
            &[]
        }
    }

    /// Column coordinates of the stored entries.
    fn col_indices(&self) -> &[usize] {
        match self {
            Self::COO(inner) => inner.col_indices.as_slice(),
            Self::CSR(inner) => inner.col_indices.as_slice(),
        }
    }

    /// Values of the stored entries.
    fn values(&self) -> &DenseTensor {
        match self {
            Self::COO(inner) => &inner.values,
            Self::CSR(inner) => &inner.values,
        }
    }
}
396
#[cfg(feature = "tensor")]
impl TensorBase for SparseTensor {
    /// Logical shape as `[rows, cols]`.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    /// Element type; always `DType::F64`.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Device; always `Device::Cpu`.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Expands the tensor into a row-major dense `[rows, cols]` tensor.
    ///
    /// Entries that share a coordinate are summed, so the dense form agrees
    /// with `spmv` and `spmm`, which also add duplicate entries together.
    ///
    /// # Panics
    ///
    /// Panics if a stored entry maps outside the `rows * cols` buffer.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let (rows, cols) = (shape[0], shape[1]);
        let mut data = vec![0.0; rows * cols];

        let coo = self.to_coo();
        let stored = coo.values.data();
        for (i, (&row, &col)) in coo
            .row_indices
            .iter()
            .zip(coo.col_indices.iter())
            .enumerate()
        {
            // Accumulate rather than overwrite: duplicates must add up.
            data[row * cols + col] += stored[i];
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    /// Already sparse, so this returns a clone wrapped in `Some`.
    #[cfg(feature = "tensor")]
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}
439
#[cfg(feature = "tensor")]
impl fmt::Debug for SparseTensor {
    /// Prints a summary (shape, stored-entry count, sparsity) instead of the
    /// raw index and value arrays.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dims = self.shape();
        let shape_pair = [dims[0], dims[1]];

        let mut out = f.debug_struct("SparseTensor");
        out.field("shape", &shape_pair);
        out.field("nnz", &self.nnz());
        out.field("sparsity", &self.sparsity());
        out.finish()
    }
}
453
454/// COO 张量实现
455impl COOTensor {
456    /// 获取形状
457    pub fn shape_array(&self) -> [usize; 2] {
458        self.shape
459    }
460}
461
462/// CSR 张量实现
463impl CSRTensor {
464    /// 获取形状
465    pub fn shape_array(&self) -> [usize; 2] {
466        self.shape
467    }
468
469    /// 获取指定行的非零元素
470    pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
471        if row >= self.shape[0] {
472            return None;
473        }
474
475        let start = self.row_offsets[row];
476        let end = self.row_offsets[row + 1];
477
478        if start == end {
479            return Some(Vec::new());
480        }
481
482        let mut result = Vec::with_capacity(end - start);
483        for i in start..end {
484            result.push((self.col_indices[i], self.values.data()[i]));
485        }
486        Some(result)
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-edge 3x3 fixture shared by the structural tests.
    fn sample_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)]
    }

    /// Fully populated 2x2 matrix `[[a, b], [c, d]]` in COO layout.
    fn full_2x2(a: f64, b: f64, c: f64, d: f64) -> SparseTensor {
        let edges = vec![(0, 0, a), (0, 1, b), (1, 0, c), (1, 1, d)];
        SparseTensor::from_edges(&edges, [2, 2])
    }

    #[test]
    fn test_coo_creation() {
        let tensor = SparseTensor::from_edges(&sample_edges(), [3, 3]);

        assert_eq!(tensor.nnz(), 4);
        assert_eq!(tensor.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let tensor = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let csr = tensor.to_csr();

        assert_eq!(csr.nnz(), 4);
        // Rows 0, 1 and 2 hold 2, 1 and 1 entries respectively.
        assert_eq!(csr.row_offsets(), &[0, 2, 3, 4]);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let sparse = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let dense = sparse.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);

        let expected: [([usize; 2], f64); 3] = [([0, 1], 1.0), ([0, 2], 2.0), ([2, 0], 4.0)];
        for (index, value) in expected.iter() {
            assert_eq!(dense.get(index).unwrap(), *value);
        }
    }

    #[test]
    fn test_spmv() {
        let matrix = full_2x2(1.0, 2.0, 3.0, 4.0);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        let product = matrix.spmv(&x).unwrap();
        // [1,2; 3,4] * [1; 2] = [1*1+2*2; 3*1+4*2] = [5; 11]
        assert_eq!(product.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmm() {
        let a = full_2x2(1.0, 2.0, 3.0, 4.0);
        let b = full_2x2(5.0, 6.0, 7.0, 8.0);

        let product = a.spmm(&b).unwrap().to_dense();

        // [1,2; 3,4] * [5,6; 7,8] = [19,22; 43,50]
        let expected: [([usize; 2], f64); 4] = [
            ([0, 0], 19.0),
            ([0, 1], 22.0),
            ([1, 0], 43.0),
            ([1, 1], 50.0),
        ];
        for (index, value) in expected.iter() {
            assert_eq!(product.get(index).unwrap(), *value);
        }
    }
}
553}