1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
//! Sparse tensor operations.
//!
//! This module contains functions for working with sparse tensors.
use super::Tensor;
use crate::errors::{Result, TrustformersError};
use crate::sparse_tensor::SparseTensor;
impl Tensor {
/// Convert a dense tensor to sparse format.
///
/// # Arguments
///
/// * `threshold` - Values below this threshold will be considered zero
///
/// # Returns
///
/// A sparse tensor representation.
pub fn to_sparse(&self, threshold: f32) -> Result<Tensor> {
match self {
Tensor::F32(a) => {
let sparse = SparseTensor::from_dense(&Tensor::F32(a.clone()), threshold)?;
Ok(Tensor::Sparse(sparse))
},
Tensor::Sparse(_) => {
// Already sparse
Ok(self.clone())
},
_ => Err(TrustformersError::tensor_op_error(
"Cannot convert this tensor type to sparse",
"Tensor::to_sparse",
)),
}
}
/// Convert a sparse tensor to dense format.
///
/// # Returns
///
/// A dense tensor representation.
pub fn to_dense(&self) -> Result<Tensor> {
match self {
Tensor::Sparse(s) => s.to_dense(),
Tensor::F32(_) | Tensor::F64(_) | Tensor::I64(_) => {
// Already dense
Ok(self.clone())
},
_ => Err(TrustformersError::tensor_op_error(
"Cannot convert this tensor type to dense",
"Tensor::to_dense",
)),
}
}
/// Check if the tensor is sparse.
///
/// # Returns
///
/// True if the tensor is sparse, false otherwise.
pub fn is_sparse(&self) -> bool {
matches!(self, Tensor::Sparse(_))
}
/// Get the sparsity ratio of the tensor.
///
/// # Returns
///
/// The ratio of zero elements to total elements.
pub fn sparsity(&self) -> Result<f32> {
match self {
Tensor::Sparse(s) => Ok(s.sparsity()),
Tensor::F32(a) => {
let total = a.len() as f32;
let zeros = a.iter().filter(|&&x| x == 0.0).count() as f32;
Ok(zeros / total)
},
_ => Err(TrustformersError::tensor_op_error(
"Sparsity calculation not supported for this tensor type",
"Tensor::sparsity",
)),
}
}
/// Get the number of non-zero elements.
///
/// # Returns
///
/// The number of non-zero elements.
pub fn nnz(&self) -> Result<usize> {
match self {
Tensor::Sparse(s) => Ok(s.nnz()),
Tensor::F32(a) => Ok(a.iter().filter(|&&x| x != 0.0).count()),
_ => Err(TrustformersError::tensor_op_error(
"NNZ calculation not supported for this tensor type",
"Tensor::nnz",
)),
}
}
/// Create a sparse tensor in COO format.
///
/// # Arguments
///
/// * `indices` - Coordinate indices
/// * `values` - Non-zero values
/// * `shape` - Tensor shape
///
/// # Returns
///
/// A sparse tensor in COO format.
pub fn sparse_coo(
indices: Vec<Vec<usize>>,
values: Vec<f32>,
shape: Vec<usize>,
) -> Result<Tensor> {
let sparse = SparseTensor::new_coo(shape, indices[0].clone(), indices[1].clone(), values)?;
Ok(Tensor::Sparse(sparse))
}
/// Create a sparse tensor in CSR format.
///
/// # Arguments
///
/// * `row_ptr` - Row pointers
/// * `col_indices` - Column indices
/// * `values` - Non-zero values
/// * `shape` - Tensor shape
///
/// # Returns
///
/// A sparse tensor in CSR format.
pub fn sparse_csr(
row_ptr: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<f32>,
shape: Vec<usize>,
) -> Result<Tensor> {
let sparse = SparseTensor::new_csr(shape, row_ptr, col_indices, values)?;
Ok(Tensor::Sparse(sparse))
}
}