1use core::fmt;
7
8use crate::tensor::traits::{TensorBase, DType, Device, COOView, SparseTensorOps};
9use crate::tensor::dense::DenseTensor;
10use crate::tensor::error::TensorError;
11
12#[derive(Debug, Clone)]
14pub struct COOTensor {
15 row_indices: Vec<usize>,
16 col_indices: Vec<usize>,
17 values: DenseTensor,
18 shape: [usize; 2],
19}
20
#[cfg(feature = "tensor-sparse")]
impl COOTensor {
    /// Builds a COO tensor from parallel coordinate arrays and a value tensor.
    ///
    /// # Panics
    ///
    /// Panics if `row_indices` and `col_indices` differ in length, if that length
    /// differs from `values.numel()`, or if any coordinate lies outside `shape`.
    /// The bounds check matters because densifying computes `row * cols + col`,
    /// so an out-of-range column would otherwise land silently in another row.
    pub fn new(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            row_indices.len(),
            col_indices.len(),
            "Row and column indices must have the same length"
        );
        assert_eq!(
            row_indices.len(),
            values.numel(),
            "Indices length must match values length"
        );
        for (k, (&r, &c)) in row_indices.iter().zip(&col_indices).enumerate() {
            assert!(
                r < shape[0] && c < shape[1],
                "Entry {} at ({}, {}) is out of bounds for shape {:?}",
                k,
                r,
                c,
                shape
            );
        }
        Self {
            row_indices,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of explicitly stored entries (duplicate coordinates are counted separately).
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Builds a COO tensor from `(row, col, value)` triples, keeping their order.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`COOTensor::new`].
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        // One pass over the triples instead of three separate projections.
        let mut row_indices = Vec::with_capacity(edges.len());
        let mut col_indices = Vec::with_capacity(edges.len());
        let mut values_data = Vec::with_capacity(edges.len());
        for &(r, c, v) in edges {
            row_indices.push(r);
            col_indices.push(c);
            values_data.push(v);
        }
        let values = DenseTensor::new(values_data, vec![edges.len()]);
        Self::new(row_indices, col_indices, values, shape)
    }

    /// Row coordinate of every stored entry, in storage order.
    pub fn row_indices(&self) -> &[usize] {
        &self.row_indices
    }

    /// Column coordinate of every stored entry, in storage order.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// The stored values, aligned with the coordinate arrays.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
72
73#[derive(Debug, Clone)]
75pub struct CSRTensor {
76 row_offsets: Vec<usize>,
77 col_indices: Vec<usize>,
78 values: DenseTensor,
79 shape: [usize; 2],
80}
81
#[cfg(feature = "tensor-sparse")]
impl CSRTensor {
    /// Builds a CSR tensor from row offsets, column indices and values.
    ///
    /// # Panics
    ///
    /// Panics if `col_indices.len() != values.numel()`, or if the CSR structure
    /// is malformed:
    /// - `row_offsets` does not have `shape[0] + 1` entries;
    /// - it does not start at 0;
    /// - it decreases anywhere;
    /// - it does not end at the number of stored entries;
    /// - any column index is `>= shape[1]`.
    ///
    /// Checking here means `row`/`to_coo`/`spmv` cannot later fail with an
    /// unrelated out-of-bounds or length-mismatch panic.
    pub fn new(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        assert_eq!(
            col_indices.len(),
            values.numel(),
            "Column indices length must match values length"
        );
        assert_eq!(
            row_offsets.len(),
            shape[0] + 1,
            "Row offsets must have one entry per row plus one"
        );
        assert_eq!(row_offsets[0], 0, "Row offsets must start at 0");
        assert!(
            row_offsets.windows(2).all(|w| w[0] <= w[1]),
            "Row offsets must be non-decreasing"
        );
        assert_eq!(
            row_offsets[shape[0]],
            col_indices.len(),
            "Last row offset must equal the number of stored entries"
        );
        assert!(
            col_indices.iter().all(|&c| c < shape[1]),
            "Column index out of bounds for shape {:?}",
            shape
        );
        Self {
            row_offsets,
            col_indices,
            values,
            shape,
        }
    }

    /// Number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.values.numel()
    }

    /// Converts a COO tensor to CSR using a counting sort on the row index.
    ///
    /// The sort is stable: entries within a row keep their relative COO order,
    /// and duplicate coordinates stay as separate entries.
    pub fn from_coo(coo: &COOTensor) -> Self {
        let nnz = coo.nnz();
        let src_values = coo.values.data();

        // Count entries per row, shifted by one so the prefix sum below turns
        // the histogram directly into row start offsets.
        let mut row_offsets = vec![0usize; coo.shape[0] + 1];
        for &row in &coo.row_indices {
            row_offsets[row + 1] += 1;
        }
        for i in 1..row_offsets.len() {
            row_offsets[i] += row_offsets[i - 1];
        }

        // Scatter each entry into the next free slot of its row.
        let mut next_slot = row_offsets.clone();
        let mut col_indices = vec![0usize; nnz];
        let mut values_data = vec![0.0; nnz];
        for ((&row, &col), &val) in coo
            .row_indices
            .iter()
            .zip(&coo.col_indices)
            .zip(src_values)
        {
            let pos = next_slot[row];
            col_indices[pos] = col;
            values_data[pos] = val;
            next_slot[row] += 1;
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Self::new(row_offsets, col_indices, values, coo.shape)
    }

    /// Row start offsets (length `rows + 1`).
    pub fn row_offsets(&self) -> &[usize] {
        &self.row_offsets
    }

    /// Column coordinate of every stored entry, grouped by row.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// The stored values, aligned with `col_indices`.
    pub fn values(&self) -> &DenseTensor {
        &self.values
    }
}
148
149#[derive(Clone)]
151pub enum SparseTensor {
152 COO(COOTensor),
154 CSR(CSRTensor),
156}
157
#[cfg(feature = "tensor-sparse")]
impl SparseTensor {
    /// Creates a COO-backed sparse tensor; preconditions are those of [`COOTensor::new`].
    pub fn coo(
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::COO(COOTensor::new(row_indices, col_indices, values, shape))
    }

    /// Creates a CSR-backed sparse tensor; preconditions are those of [`CSRTensor::new`].
    pub fn csr(
        row_offsets: Vec<usize>,
        col_indices: Vec<usize>,
        values: DenseTensor,
        shape: [usize; 2],
    ) -> Self {
        SparseTensor::CSR(CSRTensor::new(row_offsets, col_indices, values, shape))
    }

    /// Number of explicitly stored entries in either layout.
    pub fn nnz(&self) -> usize {
        match self {
            SparseTensor::COO(coo) => coo.nnz(),
            SparseTensor::CSR(csr) => csr.nnz(),
        }
    }

    /// Returns a CSR copy of this tensor, converting from COO when needed.
    pub fn to_csr(&self) -> CSRTensor {
        match self {
            SparseTensor::COO(coo) => CSRTensor::from_coo(coo),
            SparseTensor::CSR(csr) => csr.clone(),
        }
    }

    /// Returns a COO copy of this tensor.
    ///
    /// CSR row offsets are expanded into per-entry row indices in row-major order.
    pub fn to_coo(&self) -> COOTensor {
        match self {
            SparseTensor::COO(coo) => coo.clone(),
            SparseTensor::CSR(csr) => {
                let data = csr.values.data();
                let mut row_indices = Vec::with_capacity(csr.nnz());
                let mut values_data = Vec::with_capacity(csr.nnz());

                for row in 0..csr.shape[0] {
                    let (start, end) = (csr.row_offsets[row], csr.row_offsets[row + 1]);
                    for i in start..end {
                        row_indices.push(row);
                        values_data.push(data[i]);
                    }
                }

                let values = DenseTensor::new(values_data, vec![csr.nnz()]);
                COOTensor::new(row_indices, csr.col_indices.clone(), values, csr.shape)
            }
        }
    }

    /// Borrows the COO arrays without copying.
    ///
    /// NOTE: CSR storage has no per-entry row indices to borrow, so the CSR variant
    /// returns an *empty* view with shape `[0, 0]`, not the tensor's entries. Use
    /// [`SparseTensor::to_coo`] to obtain the entries of a CSR tensor.
    pub fn coo_view(&self) -> COOView<'_> {
        match self {
            SparseTensor::COO(coo) => {
                COOView::new(&coo.row_indices, &coo.col_indices, coo.values.data(), coo.shape)
            }
            SparseTensor::CSR(_) => COOView::new(&[], &[], &[], [0, 0]),
        }
    }

    /// Builds a COO-backed tensor from `(row, col, value)` triples.
    pub fn from_edges(edges: &[(usize, usize, f64)], shape: [usize; 2]) -> Self {
        SparseTensor::COO(COOTensor::from_edges(edges, shape))
    }

    /// Sparse matrix × dense vector product `y = A x`.
    ///
    /// Duplicate coordinates contribute additively. Entries are visited in storage
    /// order (row-major for CSR), without materialising a COO copy.
    ///
    /// # Errors
    ///
    /// Returns [`TensorError::DimensionMismatch`] if the tensor is not 2-D, and
    /// [`TensorError::ShapeMismatch`] if `x` is not a vector of length `cols`.
    pub fn spmv(&self, x: &DenseTensor) -> Result<DenseTensor, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }

        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];

        if x.shape() != [cols] {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols],
                got: x.shape().to_vec(),
            });
        }

        let xs = x.data();
        let mut result = vec![0.0; rows];

        match self {
            SparseTensor::COO(coo) => {
                let vals = coo.values.data();
                for ((&row, &col), &val) in
                    coo.row_indices.iter().zip(&coo.col_indices).zip(vals)
                {
                    result[row] += val * xs[col];
                }
            }
            SparseTensor::CSR(csr) => {
                let vals = csr.values.data();
                for (row, acc) in result.iter_mut().enumerate() {
                    for k in csr.row_offsets[row]..csr.row_offsets[row + 1] {
                        *acc += vals[k] * xs[csr.col_indices[k]];
                    }
                }
            }
        }

        Ok(DenseTensor::new(result, vec![rows]))
    }

    /// Sparse × sparse matrix product `C = A B`.
    ///
    /// The result is COO-backed, with entries sorted by `(row, col)`.
    ///
    /// Each stored `A[r, k]` is combined only with the stored entries of row `k`
    /// of `B`, located through B's CSR row offsets. The work is therefore
    /// proportional to the number of partial products, not `nnz(A) * nnz(B)`.
    /// Output coordinates whose contributions cancel to `0.0` are still stored.
    ///
    /// # Errors
    ///
    /// Returns [`TensorError::ShapeMismatch`] if A's column count differs from
    /// B's row count.
    pub fn spmm(&self, other: &Self) -> Result<Self, TensorError> {
        use std::collections::BTreeMap;

        let shape_a = self.shape();
        let shape_b = other.shape();
        let (rows_a, cols_a) = (shape_a[0], shape_a[1]);
        let (rows_b, cols_b) = (shape_b[0], shape_b[1]);

        if cols_a != rows_b {
            return Err(TensorError::ShapeMismatch {
                expected: vec![cols_a],
                got: vec![rows_b],
            });
        }

        let coo_a = self.to_coo();
        let csr_b = other.to_csr();
        let a_vals = coo_a.values.data();
        let b_vals = csr_b.values.data();

        // BTreeMap iterates keys in (row, col) order, so no separate sort is needed.
        let mut acc: BTreeMap<(usize, usize), f64> = BTreeMap::new();
        for ((&row_a, &col_a), &val_a) in coo_a
            .row_indices
            .iter()
            .zip(&coo_a.col_indices)
            .zip(a_vals)
        {
            for k in csr_b.row_offsets[col_a]..csr_b.row_offsets[col_a + 1] {
                *acc.entry((row_a, csr_b.col_indices[k])).or_insert(0.0) += val_a * b_vals[k];
            }
        }

        let nnz = acc.len();
        let mut row_indices = Vec::with_capacity(nnz);
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values_data = Vec::with_capacity(nnz);
        for ((row, col), val) in acc {
            row_indices.push(row);
            col_indices.push(col);
            values_data.push(val);
        }

        let values = DenseTensor::new(values_data, vec![nnz]);
        Ok(SparseTensor::COO(COOTensor::new(
            row_indices,
            col_indices,
            values,
            [rows_a, cols_b],
        )))
    }
}
314
#[cfg(feature = "tensor-sparse")]
impl SparseTensorOps for SparseTensor {
    /// Number of explicitly stored entries in either layout.
    fn nnz(&self) -> usize {
        match self {
            Self::COO(inner) => inner.nnz(),
            Self::CSR(inner) => inner.nnz(),
        }
    }

    /// Delegates to the inherent `coo_view`.
    ///
    /// NOTE(review): for the CSR variant the inherent `coo_view` yields an empty
    /// `[0, 0]` view rather than the tensor's entries.
    fn coo(&self) -> COOView<'_> {
        SparseTensor::coo_view(self)
    }

    /// Row coordinates of the stored entries.
    ///
    /// NOTE(review): CSR keeps row offsets, not per-entry row indices, so the CSR
    /// variant returns an empty slice even when `nnz() > 0`.
    fn row_indices(&self) -> &[usize] {
        match self {
            Self::COO(inner) => inner.row_indices(),
            Self::CSR(_) => &[],
        }
    }

    /// Column coordinates of the stored entries, in storage order.
    fn col_indices(&self) -> &[usize] {
        match self {
            Self::COO(inner) => inner.col_indices(),
            Self::CSR(inner) => inner.col_indices(),
        }
    }

    /// The stored values, aligned with the coordinate arrays.
    fn values(&self) -> &DenseTensor {
        match self {
            Self::COO(inner) => inner.values(),
            Self::CSR(inner) => inner.values(),
        }
    }
}
349
#[cfg(feature = "tensor-sparse")]
impl TensorBase for SparseTensor {
    /// The `[rows, cols]` extent of the underlying matrix.
    fn shape(&self) -> &[usize] {
        match self {
            SparseTensor::COO(coo) => &coo.shape[..],
            SparseTensor::CSR(csr) => &csr.shape[..],
        }
    }

    /// Sparse tensors always report `f64` values.
    fn dtype(&self) -> DType {
        DType::F64
    }

    /// Sparse tensors always report the CPU device.
    fn device(&self) -> Device {
        Device::Cpu
    }

    /// Materialises the tensor as a row-major dense `[rows, cols]` tensor.
    ///
    /// Duplicate coordinates are summed, consistent with `spmv`/`spmm`, which
    /// accumulate every stored entry.
    ///
    /// # Panics
    ///
    /// Panics if a stored coordinate lies outside `shape`. Without this check, a
    /// column `>= cols` would be written into a neighbouring row.
    fn to_dense(&self) -> DenseTensor {
        let shape = self.shape();
        let rows = shape[0];
        let cols = shape[1];
        let mut data = vec![0.0; rows * cols];
        let coo = self.to_coo();

        for ((&row, &col), &val) in coo
            .row_indices
            .iter()
            .zip(&coo.col_indices)
            .zip(coo.values.data())
        {
            assert!(
                row < rows && col < cols,
                "Entry ({}, {}) is out of bounds for shape [{}, {}]",
                row,
                col,
                rows,
                cols
            );
            data[row * cols + col] += val;
        }

        DenseTensor::new(data, vec![rows, cols])
    }

    /// A sparse tensor is already sparse: returns a clone of itself.
    #[cfg(feature = "tensor-sparse")]
    fn to_sparse(&self) -> Option<SparseTensor> {
        Some(self.clone())
    }
}
387
#[cfg(feature = "tensor-sparse")]
impl fmt::Debug for SparseTensor {
    /// Prints a compact summary (shape, stored-entry count, sparsity) rather than
    /// the raw index and value arrays.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dims = self.shape();
        let mut summary = f.debug_struct("SparseTensor");
        summary.field("shape", &[dims[0], dims[1]]);
        summary.field("nnz", &self.nnz());
        summary.field("sparsity", &self.sparsity());
        summary.finish()
    }
}
401
402impl COOTensor {
404 pub fn shape_array(&self) -> [usize; 2] {
406 self.shape
407 }
408}
409
410impl CSRTensor {
412 pub fn shape_array(&self) -> [usize; 2] {
414 self.shape
415 }
416
417 pub fn row(&self, row: usize) -> Option<Vec<(usize, f64)>> {
419 if row >= self.shape[0] {
420 return None;
421 }
422
423 let start = self.row_offsets[row];
424 let end = self.row_offsets[row + 1];
425
426 if start == end {
427 return Some(Vec::new());
428 }
429
430 let mut result = Vec::with_capacity(end - start);
431 for i in start..end {
432 result.push((self.col_indices[i], self.values.data()[i]));
433 }
434 Some(result)
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 pattern: two entries in row 0 and one each in rows 1 and 2.
    fn sample_edges() -> Vec<(usize, usize, f64)> {
        vec![(0, 1, 1.0), (0, 2, 2.0), (1, 2, 3.0), (2, 0, 4.0)]
    }

    /// Fully populated 2x2 matrix `[[a, b], [c, d]]` in COO layout.
    fn full_2x2(a: f64, b: f64, c: f64, d: f64) -> SparseTensor {
        SparseTensor::from_edges(&[(0, 0, a), (0, 1, b), (1, 0, c), (1, 1, d)], [2, 2])
    }

    #[test]
    fn test_coo_creation() {
        let coo = SparseTensor::from_edges(&sample_edges(), [3, 3]);

        assert_eq!(coo.nnz(), 4);
        assert_eq!(coo.shape(), &[3, 3]);
    }

    #[test]
    fn test_coo_to_csr() {
        let csr = SparseTensor::from_edges(&sample_edges(), [3, 3]).to_csr();

        assert_eq!(csr.nnz(), 4);
        assert_eq!(csr.row_offsets(), &[0, 2, 3, 4]);
    }

    #[test]
    fn test_sparse_dense_conversion() {
        let dense = SparseTensor::from_edges(&sample_edges(), [3, 3]).to_dense();

        assert_eq!(dense.shape(), &[3, 3]);
        let expected = [((0, 1), 1.0), ((0, 2), 2.0), ((2, 0), 4.0)];
        for &((r, c), want) in expected.iter() {
            assert_eq!(dense.get(&[r, c]).unwrap(), want);
        }
    }

    #[test]
    fn test_spmv() {
        let sparse = full_2x2(1.0, 2.0, 3.0, 4.0);
        let x = DenseTensor::new(vec![1.0, 2.0], vec![2]);

        // [1*1 + 2*2, 3*1 + 4*2]
        let y = sparse.spmv(&x).unwrap();
        assert_eq!(y.data(), &[5.0, 11.0]);
    }

    #[test]
    fn test_spmm() {
        let a = full_2x2(1.0, 2.0, 3.0, 4.0);
        let b = full_2x2(5.0, 6.0, 7.0, 8.0);

        let product = a.spmm(&b).unwrap().to_dense();

        // [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]]
        let expected = [((0, 0), 19.0), ((0, 1), 22.0), ((1, 0), 43.0), ((1, 1), 50.0)];
        for &((r, c), want) in expected.iter() {
            assert_eq!(product.get(&[r, c]).unwrap(), want);
        }
    }
}