1use core::fmt;
7
8use crate::tensor::dense::DenseTensor;
9use crate::tensor::error::TensorError;
10use crate::tensor::traits::{COOView, DType, Device, SparseTensorOps, TensorBase};
11
12#[derive(Debug, Clone)]
14pub struct COOTensor {
15 row_indices: Vec<usize>,
16 col_indices: Vec<usize>,
17 values: DenseTensor,
18 shape: [usize; 2],
19}
20
#[cfg(feature = "tensor-sparse")]
impl COOTensor {
    /// Creates a COO tensor from parallel coordinate vectors and a values tensor.
    ///
    /// Entry `i` is stored at `(row_indices[i], col_indices[i])` with the `i`-th
    /// element of `values`. Duplicate coordinates are kept as separate entries.
    ///
    /// # Panics
    ///
    /// Panics if the two index vectors differ in length, if their length differs
    /// from `values.numel()`, or if any row index is `>= shape[0]` or any column
    /// index is `>= shape[1]`.
    pub fn new(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            row_indices.len(),
            col_indices.len(),
            "Row and column indices must have the same length"
        );
        assert_eq!(
            row_indices.len(),
            values.numel(),
            "Indices length must match values length"
        );
        // Reject out-of-range coordinates here. Otherwise they surface later as
        // index panics far from the cause, or (for `col >= shape[1]`) as silent
        // writes to the wrong cell when densifying with `row * cols + col`.
        assert!(
            row_indices.iter().all(|&row| row < shape[0]),
            "Row index out of bounds for shape {:?}",
            shape
        );
        assert!(
            col_indices.iter().all(|&col| col < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_indices,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored entries (duplicate coordinates counted separately).
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Builds a COO tensor from `(row, col, value)` triples, preserving their order.
    ///
    /// # Panics
    ///
    /// Same conditions as [`COOTensor::new`] (e.g. an edge outside `shape`).
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        // Single pass over `edges` into pre-sized buffers.
        let mut row_indices = Vec::with_capacity(edges.len());
        let mut col_indices = Vec::with_capacity(edges.len());
        let mut values_data = Vec::with_capacity(edges.len());
        for &(row, col, value) in edges {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(value);
        }
        let values = DenseTensor::new(values_data, vec![edges.len()]);
        Self::new(row_indices, col_indices, values, shape)
    }

    /// Row coordinate of each stored entry.
    pub fn row_indices(&self) -> &[usize] {
        &self.row_indices
    }

    /// Column coordinate of each stored entry.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values, parallel to the index slices.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
77
78#[derive(Debug, Clone)]
80pub struct CSRTensor {
81 row_offsets: Vec<usize>,
82 col_indices: Vec<usize>,
83 values: DenseTensor,
84 shape: [usize; 2],
85}
86
#[cfg(feature = "tensor-sparse")]
impl CSRTensor {
    /// Creates a CSR tensor from row offsets, column indices and values.
    ///
    /// Row `r` owns positions `row_offsets[r]..row_offsets[r + 1]` of
    /// `col_indices` and `values`.
    ///
    /// # Panics
    ///
    /// Panics if `col_indices.len() != values.numel()`, if `row_offsets` is not
    /// a non-decreasing sequence of length `shape[0] + 1` running from `0` to
    /// `col_indices.len()`, or if any column index is `>= shape[1]`.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            col_indices.len(),
            values.numel(),
            "Column indices length must match values length"
        );
        // Row access and COO conversion index `row_offsets[r + 1]` for every row
        // and assume the offsets cover exactly `0..nnz`; enforce that up front
        // instead of failing (or producing mismatched output) later.
        assert_eq!(
            row_offsets.len(),
            shape[0] + 1,
            "Row offsets length must be rows + 1"
        );
        assert_eq!(row_offsets[0], 0, "Row offsets must start at 0");
        assert_eq!(
            row_offsets[shape[0]],
            col_indices.len(),
            "Last row offset must equal the number of stored entries"
        );
        assert!(
            row_offsets.windows(2).all(|pair| pair[0] <= pair[1]),
            "Row offsets must be non-decreasing"
        );
        assert!(
            col_indices.iter().all(|&col| col < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_offsets,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of stored entries (duplicate coordinates counted separately).
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Converts a COO tensor to CSR using a counting sort on the row index.
    ///
    /// The sort is stable: entries of the same row keep their COO order, and
    /// duplicate coordinates are preserved rather than merged.
    ///
    /// # Panics
    ///
    /// Panics if a COO row index is `>= coo.shape[0]`.
    pub fn from_coo(coo: &COOTensor) -> Self {
        let nnz = coo.nnz();
        let rows = coo.shape[0];

        // Per-row entry counts, shifted by one slot so that the running sum
        // below turns them into each row's start offset.
        let mut row_offsets = vec![0; rows + 1];
        for &row in &coo.row_indices {
            row_offsets[row + 1] += 1;
        }
        let mut running = 0;
        for offset in row_offsets.iter_mut() {
            running += *offset;
            *offset = running;
        }

        // Next free slot of each row while scattering entries into place.
        let mut next_slot = row_offsets[..rows].to_vec();
        let mut col_indices = vec![0; nnz];
        let mut values_data = vec![0.0; nnz];
        let entries = coo
            .row_indices
            .iter()
            .zip(coo.col_indices.iter())
            .zip(coo.values.data().iter());
        for ((&row, &col), &value) in entries {
            let slot = next_slot[row];
            col_indices[slot] = col;
            values_data[slot] = value;
            next_slot[row] += 1;
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Self::new(row_offsets, col_indices, values, coo.shape)
    }

    /// Start offset of each row followed by the end offset of the last row.
    pub fn row_offsets(&self) -> &[usize] {
        &self.row_offsets
    }

    /// Column coordinate of each stored entry, grouped by row.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values, parallel to `col_indices`.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
163
164#[derive(Clone)]
166pub enum SparseTensor {
167 COO(COOTensor),
169 CSR(CSRTensor),
171}
172
#[cfg(feature = "tensor-sparse")]
impl SparseTensor {
    /// Creates a COO-backed sparse tensor.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `COOTensor::new`.
    pub fn coo(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
    }

    /// Creates a CSR-backed sparse tensor.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `CSRTensor::new`.
    pub fn csr(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
    }

    /// Number of stored entries (duplicate coordinates counted separately).
    pub fn nnz(&self) -> usize {
        match self {
            SparseTensor::COO(coo) => coo.nnz(),
            SparseTensor::CSR(csr) => csr.nnz(),
        }
    }

    /// Returns the tensor in CSR layout (a clone when it is already CSR).
    pub fn to_csr(&self) -> CSRTensor {
        match self {
            SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
            SparseTensor::CSR(csr) => csr.clone(),
        }
    }

    /// Returns the tensor in COO layout (a clone when it is already COO).
    ///
    /// CSR entries are emitted row by row, in their stored order.
    pub fn to_coo(&self) -> COOTensor {
        match self {
            SparseTensor::COO(coo) => coo.clone(),
            SparseTensor::CSR(csr) => {
                let capacity = csr.nnz();
                let mut row_indices = Vec::with_capacity(capacity);
                let mut col_indices = Vec::with_capacity(capacity);
                let mut values_data = Vec::with_capacity(capacity);
                self.for_each_entry(|row, col, value| {
                    row_indices.push(row);
                    col_indices.push(col);
                    values_data.push(value);
                });
                let len = values_data.len();
                let values = DenseTensor::new(values_data, vec![len]);
                COOTensor::new(row_indices, col_indices, values, csr.shape)
            }
        }
    }

    /// Borrowed COO view of the stored entries.
    ///
    /// CSR storage keeps no per-entry row indices to borrow, so for the CSR
    /// variant this returns an empty view with shape `[0, 0]`; call `to_coo`
    /// first when a complete view of a CSR tensor is needed.
    pub fn coo_view(&self) -> COOView<'_> {
        match self {
            SparseTensor::COO(coo) => COOView::new(
                &coo.row_indices,
                &coo.col_indices,
                coo.values.data(),
                coo.shape,
            ),
            SparseTensor::CSR(_) => COOView::new(&[], &[], &[], [0, 0]),
        }
    }

    /// Builds a COO-backed tensor from `(row, col, value)` triples.
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        SparseTensor::COO(COOTensor::from_edges(edges, shape))
    }

    /// Sparse matrix-vector product `self * x`.
    ///
    /// Entries sharing a coordinate contribute additively.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::DimensionMismatch` if the tensor is not 2-D, and
    /// `TensorError::ShapeMismatch` if `x` is not a vector of length `cols`.
    pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }

        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];

        if x.shape() != [cols] {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols],
                got: x.shape().to_vec(),
            });
        }

        let x_data = x.data();
        let mut result = vec![0.0; rows];
        // Visit stored entries in place instead of materialising a COO copy.
        self.for_each_entry(|row, col, value| result[row] += value * x_data[col]);

        Ok(DenseTensor::new(result, vec![rows]))
    }

    /// Sparse matrix-matrix product `self * other`.
    ///
    /// The result is COO-backed with entries sorted by `(row, col)`; products
    /// that cancel to zero are still stored.
    ///
    /// # Errors
    ///
    /// Returns `TensorError::ShapeMismatch` when the column count of `self`
    /// differs from the row count of `other`.
    pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
        use std::collections::BTreeMap;

        let shape_a = self.shape();
        let shape_b = other.shape();
        let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
        let (rows_b, cols_b) = (shape_b[0], shape_b[1]);

        if cols_a != rows_b {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols_a],
                got: vec![rows_b],
            });
        }

        // Row-indexed access to B lets each A entry (i, k) visit only row k of
        // B, instead of scanning all nnz(A) * nnz(B) pairs.
        let b = other.to_csr();
        let b_values = b.values.data();

        // BTreeMap keeps keys ordered by (row, col), so no separate sort is needed.
        let mut products: BTreeMap<(usize, usize), f64> = BTreeMap::new();
        self.for_each_entry(|row_a, k, val_a| {
            // A column outside B's rows can never pair with a B entry.
            if k >= rows_b {
                return;
            }
            for idx in b.row_offsets[k]..b.row_offsets[k + 1] {
                *products.entry((row_a, b.col_indices[idx])).or_insert(0.0) +=
                    val_a * b_values[idx];
            }
        });

        let nnz = products.len();
        let mut row_indices = Vec::with_capacity(nnz);
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values_data = Vec::with_capacity(nnz);
        for ((row, col), value) in products {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(value);
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Ok(SparseTensor::COO(COOTensor::new(
            row_indices,
            col_indices,
            values,
            [rows_a, cols_b],
        )))
    }

    /// Calls `visit(row, col, value)` for every stored entry without copying
    /// the storage: COO entries in stored order, CSR entries row by row.
    fn for_each_entry(&self, mut visit: impl FnMut(usize, usize, f64)) {
        match self {
            SparseTensor::COO(coo) => {
                let entries = coo
                    .row_indices
                    .iter()
                    .zip(coo.col_indices.iter())
                    .zip(coo.values.data().iter());
                for ((&row, &col), &value) in entries {
                    visit(row, col, value);
                }
            }
            SparseTensor::CSR(csr) => {
                let values = csr.values.data();
                for row in 0..csr.shape[0] {
                    for idx in csr.row_offsets[row]..csr.row_offsets[row + 1] {
                        visit(row, csr.col_indices[idx], values[idx]);
                    }
                }
            }
        }
    }
}
362
#[cfg(feature = "tensor-sparse")]
impl SparseTensorOps for SparseTensor {
    /// Stored-entry count: the element count of the values tensor.
    fn nnz(&self) -> usize {
        SparseTensorOps::values(self).numel()
    }

    /// Borrowed COO view of the entries; empty for CSR storage.
    fn coo(&self) -> COOView<'_> {
        Self::coo_view(self)
    }

    /// Row coordinates of the stored entries.
    ///
    /// CSR storage keeps no per-entry row indices, so an empty slice is
    /// returned for that variant.
    fn row_indices(&self) -> &[usize] {
        if let Self::COO(inner) = self {
            inner.row_indices()
        } else {
            &[]
        }
    }

    /// Column coordinates of the stored entries, in storage order.
    fn col_indices(&self) -> &[usize] {
        match self {
            Self::CSR(inner) => inner.col_indices(),
            Self::COO(inner) => inner.col_indices(),
        }
    }

    /// Stored values, parallel to `col_indices`.
    fn values(&self) -> &DenseTensor {
        match self {
            Self::CSR(inner) => inner.values(),
            Self::COO(inner) => inner.values(),
        }
    }
}
397
#[cfg(feature = "tensor-sparse")]
impl TensorBase for SparseTensor {
    /// The `[rows, cols]` extent of the underlying storage.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    /// Always reports `DType::F64`.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Always reports `Device::Cpu`.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Densifies into a row-major `[rows, cols]` tensor.
    ///
    /// Entries sharing a coordinate are summed, consistent with how `spmv` and
    /// `spmm` treat duplicate COO entries.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let (rows, cols) = (shape[0], shape[1]);
        let mut data = vec![0.0; rows * cols];
        let coo = self.to_coo();

        let entries = coo
            .row_indices
            .iter()
            .zip(coo.col_indices.iter())
            .zip(coo.values.data().iter());
        for ((&row, &col), &value) in entries {
            // Accumulate instead of overwriting: COO may hold duplicates, and
            // overwriting would keep only the last one.
            data[row * cols + col] += value;
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    /// Already sparse: returns a clone of `self`.
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}
440
#[cfg(feature = "tensor-sparse")]
impl fmt::Debug for SparseTensor {
    /// Summarises the tensor as its shape, stored-entry count and sparsity
    /// rather than dumping every entry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let extent = self.shape();
        let dims: [usize; 2] = [extent[0], extent[1]];
        let mut summary = f.debug_struct("SparseTensor");
        summary.field("shape", &dims);
        summary.field("nnz", &self.nnz());
        summary.field("sparsity", &self.sparsity());
        summary.finish()
    }
}
454
455impl COOTensor {
457 pub fn shape_array(&self) -> [usize; 2] {
459 self.shape
460 }
461}
462
463impl CSRTensor {
465 pub fn shape_array(&self) -> [usize; 2] {
467 self.shape
468 }
469
470 pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
472 if row >= self.shape[0] {
473 return None;
474 }
475
476 let start = self.row_offsets[row];
477 let end = self.row_offsets[row + 1];
478
479 if start == end {
480 return Some(Vec::new());
481 }
482
483 let mut result = Vec::with_capacity(end - start);
484 for i in start..end {
485 result.push((self.col_indices[i], self.values.data()[i]));
486 }
487 Some(result)
488 }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Cell order shared by the 2x2 fixtures: row-major.
    const CELLS_2X2: [(usize, usize); 4] = [(0, 0), (0, 1), (1, 0), (1, 1)];

    /// 3x3 fixture with four entries, two of them in row 0.
    fn three_by_three() -> SparseTensor {
        let edges: [(usize, usize, f64); 4] =
            [(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)];
        SparseTensor::from_edges(&edges, [3, 3])
    }

    /// Fully populated 2x2 matrix whose entries are `values` in row-major order.
    fn two_by_two(values: [f64; 4]) -> SparseTensor {
        let edges: Vec<(usize, usize, f64)> = CELLS_2X2
            .iter()
            .zip(values.iter())
            .map(|(&(row, col), &value)| (row, col, value))
            .collect();
        SparseTensor::from_edges(&edges, [2, 2])
    }

    #[test]
    fn test_coo_creation() {
        let tensor = three_by_three();
        assert_eq!(tensor.nnz(), 4);
        assert_eq!(tensor.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let compressed = three_by_three().to_csr();
        assert_eq!(compressed.nnz(), 4);
        assert_eq!(compressed.row_offsets(), &[0, 2, 3, 4]);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let dense = three_by_three().to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(dense.get(&[2, 0]).unwrap(), 4.0);
    }

    #[test]
    fn test_spmv() {
        let matrix = two_by_two([1.0, 2.0, 3.0, 4.0]);
        let vector = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let product = matrix.spmv(&vector).unwrap();
        assert_eq!(product.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmm() {
        let lhs = two_by_two([1.0, 2.0, 3.0, 4.0]);
        let rhs = two_by_two([5.0, 6.0, 7.0, 8.0]);
        let product = lhs.spmm(&rhs).unwrap().to_dense();

        let expected: [([usize; 2], f64); 4] = [
            ([0, 0], 19.0),
            ([0, 1], 22.0),
            ([1, 0], 43.0),
            ([1, 1], 50.0),
        ];
        for (index, value) in expected.iter() {
            assert_eq!(product.get(index).unwrap(), *value);
        }
    }
}