1use core::fmt;
7
8use crate::tensor::dense::DenseTensor;
9use crate::tensor::error::TensorError;
10use crate::tensor::traits::{COOView, DType, Device, SparseTensorOps, TensorBase};
11
12#[derive(Debug, Clone)]
14pub struct COOTensor {
15 row_indices: Vec<usize>,
16 col_indices: Vec<usize>,
17 values: DenseTensor,
18 shape: [usize; 2],
19}
20
21impl COOTensor {
22 pub fn new(
24 row_indices: Vec<usize>,
25 col_indices: Vec<usize>,
26 values: DenseTensor,
27 shape: [usize; 2],
28 ) -> Self {
29 assert_eq!(
30 row_indices.len(),
31 col_indices.len(),
32 "Row and column indices must have the same length"
33 );
34 assert_eq!(
35 row_indices.len(),
36 values.numel(),
37 "Indices length must match values length"
38 );
39 Self {
40 row_indices,
41 col_indices,
42 values,
43 shape,
44 }
45 }
46
47 pub fn nnz(&self) -> usize {
49 self.values.numel()
50 }
51
52 pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
54 let row_indices: Vec<usize> = edges.iter().map(|&(r, _, _)| r).collect();
55 let col_indices: Vec<usize> = edges.iter().map(|&(_, c, _)| c).collect();
56 let values_data: Vec<f64> = edges.iter().map(|&(_, _, v)| v).collect();
57 let values = DenseTensor::new(values_data, vec![edges.len()]);
58 Self::new(row_indices, col_indices, values, shape)
59 }
60
61 pub fn row_indices(&self) -> &[usize] {
63 &self.row_indices
64 }
65
66 pub fn col_indices(&self) -> &[usize] {
68 &self.col_indices
69 }
70
71 pub fn values(&self) -> &DenseTensor {
73 &self.values
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct CSRTensor {
80 row_offsets: Vec<usize>,
81 col_indices: Vec<usize>,
82 values: DenseTensor,
83 shape: [usize; 2],
84}
85
86#[cfg(feature = "tensor")]
87impl CSRTensor {
88 pub fn new(
90 row_offsets: Vec<usize>,
91 col_indices: Vec<usize>,
92 values: DenseTensor,
93 shape: [usize; 2],
94 ) -> Self {
95 assert_eq!(
96 col_indices.len(),
97 values.numel(),
98 "Column indices length must match values length"
99 );
100 Self {
101 row_offsets,
102 col_indices,
103 values,
104 shape,
105 }
106 }
107
108 pub fn nnz(&self) -> usize {
110 self.values.numel()
111 }
112
113 pub fn from_coo(coo: &COOTensor) -> Self {
115 let mut row_offsets = vec![0; coo.shape[0] + 1];
116 let mut col_indices = vec![0; coo.nnz()];
117 let mut values_data = vec![0.0; coo.nnz()];
118
119 for &row in &coo.row_indices {
121 row_offsets[row + 1] += 1;
122 }
123
124 for i in 1..row_offsets.len() {
126 row_offsets[i] += row_offsets[i - 1];
127 }
128
129 let mut row_pos = row_offsets.clone();
131 for (i, (&row, &col)) in coo
132 .row_indices
133 .iter()
134 .zip(coo.col_indices.iter())
135 .enumerate()
136 {
137 let pos = row_pos[row];
138 col_indices[pos] = col;
139 values_data[pos] = coo.values.data()[i];
140 row_pos[row] += 1;
141 }
142
143 let values = DenseTensor::new(values_data, vec![coo.nnz()]);
144 Self::new(row_offsets, col_indices, values, coo.shape)
145 }
146
147 pub fn row_offsets(&self) -> &[usize] {
149 &self.row_offsets
150 }
151
152 pub fn col_indices(&self) -> &[usize] {
154 &self.col_indices
155 }
156
157 pub fn values(&self) -> &DenseTensor {
159 &self.values
160 }
161}
162
163#[derive(Clone)]
165pub enum SparseTensor {
166 COO(COOTensor),
168 CSR(CSRTensor),
170}
171
172#[cfg(feature = "tensor")]
173impl SparseTensor {
174 pub fn coo(
176 row_indices: Vec<usize>,
177 col_indices: Vec<usize>,
178 values: DenseTensor,
179 shape: [usize; 2],
180 ) -> Self {
181 SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
182 }
183
184 pub fn csr(
186 row_offsets: Vec<usize>,
187 col_indices: Vec<usize>,
188 values: DenseTensor,
189 shape: [usize; 2],
190 ) -> Self {
191 SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
192 }
193
194 pub fn nnz(&self) -> usize {
196 match self {
197 SparseTensor::COO(coo) => coo.nnz(),
198 SparseTensor::CSR(csr) => csr.nnz(),
199 }
200 }
201
202 pub fn to_csr(&self) -> CSRTensor {
204 match self {
205 SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
206 SparseTensor::CSR(csr) => csr.clone(),
207 }
208 }
209
210 pub fn to_coo(&self) -> COOTensor {
212 match self {
213 SparseTensor::COO(coo) => coo.clone(),
214 SparseTensor::CSR(csr) => {
215 let mut row_indices = Vec::with_capacity(csr.nnz());
217 let col_indices = csr.col_indices.clone();
218 let mut values_data = Vec::with_capacity(csr.nnz());
219
220 for row in 0..csr.shape[0] {
221 let start = csr.row_offsets[row];
222 let end = csr.row_offsets[row + 1];
223 for _ in start..end {
224 row_indices.push(row);
225 }
226 for i in start..end {
227 values_data.push(csr.values.data()[i]);
228 }
229 }
230
231 let values = DenseTensor::new(values_data, vec![csr.nnz()]);
232 COOTensor::new(row_indices, col_indices, values, csr.shape)
233 }
234 }
235 }
236
237 pub fn coo_view(&self) -> COOView<'_> {
239 match self {
240 SparseTensor::COO(coo) => COOView::new(
241 &coo.row_indices,
242 &coo.col_indices,
243 coo.values.data(),
244 coo.shape,
245 ),
246 SparseTensor::CSR(_) => {
247 COOView::new(&[], &[], &[], [0, 0])
250 }
251 }
252 }
253
254 pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
256 SparseTensor::COO(COOTensor::from_edges(edges, shape))
257 }
258
259 pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
261 if self.ndim() != 2 {
262 return Err(TensorError::DimensionMismatch {
263 expected: 2,
264 got: self.ndim(),
265 });
266 }
267
268 let shape = self.shape();
269 let rows = shape[0];
270 let cols = shape[1];
271
272 if x.shape() != [cols] {
273 return Err(TensorError::ShapeMismatch {
274 expected: vec![cols],
275 got: x.shape().to_vec(),
276 });
277 }
278
279 let mut result = vec![0.0; rows];
280 let coo = self.to_coo();
281
282 for (i, (&row, &col)) in coo
283 .row_indices
284 .iter()
285 .zip(coo.col_indices.iter())
286 .enumerate()
287 {
288 let val = coo.values.data()[i];
289 let x_val = x.data()[col];
290 result[row] += val * x_val;
291 }
292
293 Ok(DenseTensor::new(result, vec![rows]))
294 }
295
296 pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
298 let shape_a = self.shape();
299 let shape_b = other.shape();
300 let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
301 let (rows_b, cols_b) = (shape_b[0], shape_b[1]);
302
303 if cols_a != rows_b {
304 return Err(TensorError::ShapeMismatch {
305 expected: vec![cols_a],
306 got: vec![rows_b],
307 });
308 }
309
310 let coo_a = self.to_coo();
312 let coo_b = other.to_coo();
313
314 use std::collections::HashMap;
316 let mut result_map: HashMap<(usize, usize), f64> = HashMap::new();
317
318 for (i, (&row_a, &col_a)) in coo_a
319 .row_indices
320 .iter()
321 .zip(coo_a.col_indices.iter())
322 .enumerate()
323 {
324 let val_a = coo_a.values.data()[i];
325 for (j, (&row_b, &col_b)) in coo_b
326 .row_indices
327 .iter()
328 .zip(coo_b.col_indices.iter())
329 .enumerate()
330 {
331 if col_a == row_b {
332 let val_b = coo_b.values.data()[j];
333 *result_map.entry((row_a, col_b)).or_insert(0.0) += val_a * val_b;
334 }
335 }
336 }
337
338 let mut row_indices = Vec::new();
340 let mut col_indices = Vec::new();
341 let mut values_data = Vec::new();
342
343 let mut entries: Vec<_> = result_map.into_iter().collect();
344 entries.sort_by_key(|&(pos, _)| pos);
345
346 for ((row, col), val) in entries {
347 row_indices.push(row);
348 col_indices.push(col);
349 values_data.push(val);
350 }
351
352 let values = DenseTensor::new(values_data.clone(), vec![values_data.len()]);
353 Ok(SparseTensor::COO(COOTensor::new(
354 row_indices,
355 col_indices,
356 values,
357 [rows_a, cols_b],
358 )))
359 }
360}
361
/// [`SparseTensorOps`] over either storage layout.
///
/// NOTE: the CSR layout stores row offsets rather than per-entry row indices,
/// so for the CSR variant `row_indices` returns an empty slice and `coo`
/// returns an empty view. Use `to_coo` to obtain full coordinates.
#[cfg(feature = "tensor")]
impl SparseTensorOps for SparseTensor {
    fn nnz(&self) -> usize {
        match self {
            Self::COO(inner) => inner.nnz(),
            Self::CSR(inner) => inner.nnz(),
        }
    }

    fn coo(&self) -> COOView<'_> {
        Self::coo_view(self)
    }

    fn row_indices(&self) -> &[usize] {
        match self {
            Self::COO(inner) => &inner.row_indices,
            // CSR keeps no per-entry row indices that could be borrowed.
            Self::CSR(_) => &[],
        }
    }

    fn col_indices(&self) -> &[usize] {
        match self {
            Self::COO(inner) => &inner.col_indices,
            Self::CSR(inner) => &inner.col_indices,
        }
    }

    fn values(&self) -> &DenseTensor {
        match self {
            Self::COO(inner) => &inner.values,
            Self::CSR(inner) => &inner.values,
        }
    }
}
396
#[cfg(feature = "tensor")]
impl TensorBase for SparseTensor {
    /// The `[rows, cols]` shape of the underlying storage.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    /// Sparse tensors here always store `f64` values.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Sparse tensors here always live on the CPU.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Materializes the matrix as a row-major `[rows, cols]` dense tensor.
    ///
    /// Duplicate coordinates are summed, consistent with `spmv` and `spmm`.
    /// Either storage layout is traversed directly, without building a COO copy.
    ///
    /// # Panics
    ///
    /// Panics if a stored coordinate lies outside the shape, instead of
    /// writing into a neighbouring cell of the flat buffer.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let (rows, cols) = (shape[0], shape[1]);
        let mut data = vec![0.0; rows * cols];

        {
            let mut add = |row: usize, col: usize, val: f64| {
                assert!(
                    row < rows && col < cols,
                    "Sparse index ({}, {}) out of bounds for shape [{}, {}]",
                    row,
                    col,
                    rows,
                    cols
                );
                data[row * cols + col] += val;
            };

            match self {
                SparseTensor::COO(coo) => {
                    let triples = coo
                        .row_indices
                        .iter()
                        .zip(&coo.col_indices)
                        .zip(coo.values.data());
                    for ((&row, &col), &val) in triples {
                        add(row, col, val);
                    }
                }
                SparseTensor::CSR(csr) => {
                    let values = csr.values.data();
                    for row in 0..csr.shape[0] {
                        let (start, end) = (csr.row_offsets[row], csr.row_offsets[row + 1]);
                        let pairs = csr.col_indices[start..end].iter().zip(&values[start..end]);
                        for (&col, &val) in pairs {
                            add(row, col, val);
                        }
                    }
                }
            }
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    #[cfg(feature = "tensor")]
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}
439
#[cfg(feature = "tensor")]
impl fmt::Debug for SparseTensor {
    /// Prints a compact summary (shape, stored-entry count, sparsity)
    /// instead of every stored entry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let extent = self.shape();
        let dims = [extent[0], extent[1]];

        let mut summary = f.debug_struct("SparseTensor");
        summary.field("shape", &dims);
        summary.field("nnz", &self.nnz());
        summary.field("sparsity", &self.sparsity());
        summary.finish()
    }
}
453
454impl COOTensor {
456 pub fn shape_array(&self) -> [usize; 2] {
458 self.shape
459 }
460}
461
462impl CSRTensor {
464 pub fn shape_array(&self) -> [usize; 2] {
466 self.shape
467 }
468
469 pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
471 if row >= self.shape[0] {
472 return None;
473 }
474
475 let start = self.row_offsets[row];
476 let end = self.row_offsets[row + 1];
477
478 if start == end {
479 return Some(Vec::new());
480 }
481
482 let mut result = Vec::with_capacity(end - start);
483 for i in start..end {
484 result.push((self.col_indices[i], self.values.data()[i]));
485 }
486 Some(result)
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 matrix [[0, 1, 2], [0, 0, 3], [4, 0, 0]] as COO triples.
    fn sample_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)]
    }

    /// 2x2 matrix [[1, 2], [3, 4]] as COO triples.
    fn square_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)]
    }

    #[test]
    fn test_coo_creation() {
        let coo = SparseTensor::from_edges(&sample_edges(), [3, 3]);

        assert_eq!(coo.nnz(), 4);
        assert_eq!(coo.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let coo = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.row_offsets(), &[0, 2, 3, 4]);
    }

    #[test]
    fn test_csr_to_coo_round_trip() {
        // Deliberately unsorted input: CSR groups by row, keeping the
        // relative order of entries within a row.
        let edges = vec![(2, 0, 4.0), (0, 1, 1.0), (1, 2, 3.0), (0, 2, 2.0)];
        let csr = SparseTensor::CSR(SparseTensor::from_edges(&edges, [3, 3]).to_csr());
        let coo = csr.to_coo();

        assert_eq!(coo.row_indices(), &[0, 0, 1, 2]);
        assert_eq!(coo.col_indices(), &[1, 2, 2, 0]);
        assert_eq!(coo.values().data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_csr_row_access() {
        let csr = SparseTensor::from_edges(&sample_edges(), [4, 3]).to_csr();

        assert_eq!(csr.shape_array(), [4, 3]);
        assert_eq!(csr.row(0), Some(vec![(1, 1.0), (2, 2.0)]));
        assert_eq!(csr.row(2), Some(vec![(0, 4.0)]));
        assert_eq!(csr.row(3), Some(Vec::new()));
        assert_eq!(csr.row(4), None);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let sparse = SparseTensor::from_edges(&sample_edges(), [3, 3]);
        let dense = sparse.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[0, 1]).unwrap(), 1.0);
        assert_eq!(dense.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(dense.get(&[2, 0]).unwrap(), 4.0);
        assert_eq!(dense.get(&[1, 0]).unwrap(), 0.0);
    }

    #[test]
    fn test_csr_to_dense() {
        let csr = SparseTensor::CSR(SparseTensor::from_edges(&sample_edges(), [3, 3]).to_csr());
        let dense = csr.to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense.get(&[1, 2]).unwrap(), 3.0);
        assert_eq!(dense.get(&[2, 0]).unwrap(), 4.0);
    }

    #[test]
    fn test_spmv() {
        let sparse = SparseTensor::from_edges(&square_edges(), [2, 2]);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        let result = sparse.spmv(&x).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmv_csr() {
        let csr = SparseTensor::CSR(SparseTensor::from_edges(&square_edges(), [2, 2]).to_csr());
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        let result = csr.spmv(&x).unwrap();
        assert_eq!(result.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmv_shape_mismatch() {
        let sparse = SparseTensor::from_edges(&square_edges(), [2, 2]);
        let x = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);

        assert!(sparse.spmv(&x).is_err());
    }

    #[test]
    fn test_spmm() {
        let a = SparseTensor::from_edges(&square_edges(), [2, 2]);
        let edges_b = vec![(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];
        let b = SparseTensor::from_edges(&edges_b, [2, 2]);

        let result_dense = a.spmm(&b).unwrap().to_dense();

        assert_eq!(result_dense.get(&[0, 0]).unwrap(), 19.0);
        assert_eq!(result_dense.get(&[0, 1]).unwrap(), 22.0);
        assert_eq!(result_dense.get(&[1, 0]).unwrap(), 43.0);
        assert_eq!(result_dense.get(&[1, 1]).unwrap(), 50.0);
    }

    #[test]
    fn test_spmm_csr_operands() {
        let a = SparseTensor::CSR(SparseTensor::from_edges(&square_edges(), [2, 2]).to_csr());
        let edges_b = vec![(0, 0, 5.0), (0, 1, 6.0), (1, 0, 7.0), (1, 1, 8.0)];
        let b = SparseTensor::CSR(SparseTensor::from_edges(&edges_b, [2, 2]).to_csr());

        let result = a.spmm(&b).unwrap();
        assert_eq!(result.nnz(), 4);

        let result_dense = result.to_dense();
        assert_eq!(result_dense.get(&[0, 0]).unwrap(), 19.0);
        assert_eq!(result_dense.get(&[0, 1]).unwrap(), 22.0);
        assert_eq!(result_dense.get(&[1, 0]).unwrap(), 43.0);
        assert_eq!(result_dense.get(&[1, 1]).unwrap(), 50.0);
    }

    #[test]
    fn test_spmm_shape_mismatch() {
        let a = SparseTensor::from_edges(&[(0, 2, 1.0)], [2, 3]);
        let b = SparseTensor::from_edges(&square_edges(), [2, 2]);

        assert!(a.spmm(&b).is_err());
    }
}