1use super::SparseTensor;
4use crate::dtype::{DType, Element};
5use crate::error::{Error, Result};
6use crate::runtime::Runtime;
7use crate::sparse::SparseFormat;
8use crate::tensor::Tensor;
9
10impl<R: Runtime<DType = DType>> SparseTensor<R> {
11 pub fn to_coo(&self) -> Result<SparseTensor<R>> {
17 match self {
18 SparseTensor::Coo(d) => Ok(SparseTensor::Coo(d.clone())),
19 SparseTensor::Csr(d) => Ok(SparseTensor::Coo(d.to_coo()?)),
20 SparseTensor::Csc(d) => Ok(SparseTensor::Coo(d.to_coo()?)),
21 }
22 }
23
24 pub fn to_csr(&self) -> Result<SparseTensor<R>> {
26 match self {
27 SparseTensor::Coo(d) => Ok(SparseTensor::Csr(d.to_csr()?)),
28 SparseTensor::Csr(d) => Ok(SparseTensor::Csr(d.clone())),
29 SparseTensor::Csc(d) => Ok(SparseTensor::Csr(d.to_csr()?)),
30 }
31 }
32
33 pub fn to_csc(&self) -> Result<SparseTensor<R>> {
35 match self {
36 SparseTensor::Coo(d) => Ok(SparseTensor::Csc(d.to_csc()?)),
37 SparseTensor::Csr(d) => Ok(SparseTensor::Csc(d.to_csc()?)),
38 SparseTensor::Csc(d) => Ok(SparseTensor::Csc(d.clone())),
39 }
40 }
41
42 pub fn to_format(&self, format: SparseFormat) -> Result<SparseTensor<R>> {
44 match format {
45 SparseFormat::Coo => self.to_coo(),
46 SparseFormat::Csr => self.to_csr(),
47 SparseFormat::Csc => self.to_csc(),
48 }
49 }
50
51 pub fn to_dense(&self, device: &R::Device) -> Result<Tensor<R>> {
83 let [nrows, ncols] = self.shape();
84 let dtype = self.dtype();
85 let numel = nrows * ncols;
86
87 if self.is_empty() {
89 crate::dispatch_dtype!(dtype, T => {
90 let zeros: Vec<T> = vec![T::zero(); numel];
91 return Ok(Tensor::from_slice(&zeros, &[nrows, ncols], device));
92 }, "sparse to dense empty");
93 }
94
95 let coo = match self {
97 SparseTensor::Coo(d) => d.clone(),
98 SparseTensor::Csr(d) => d.to_coo()?,
99 SparseTensor::Csc(d) => d.to_coo()?,
100 };
101
102 let row_indices: Vec<i64> = coo.row_indices().to_vec();
104 let col_indices: Vec<i64> = coo.col_indices().to_vec();
105
106 crate::dispatch_dtype!(dtype, T => {
108 let values: Vec<T> = coo.values().to_vec();
109
110 let mut dense_data: Vec<T> = vec![T::zero(); numel];
112
113 for (i, (val, (row, col))) in values.iter()
115 .zip(row_indices.iter().zip(col_indices.iter()))
116 .enumerate()
117 {
118 let r = *row as usize;
119 let c = *col as usize;
120
121 if r >= nrows || c >= ncols {
122 return Err(Error::IndexOutOfBounds {
123 index: i,
124 size: numel,
125 });
126 }
127
128 let idx = r * ncols + c;
129 dense_data[idx] = *val;
130 }
131
132 return Ok(Tensor::from_slice(&dense_data, &[nrows, ncols], device));
133 }, "sparse to dense conversion");
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::{CpuClient, CpuRuntime};
    use crate::sparse::SparseFormat;
    use crate::tensor::Tensor;

    /// Row-major dense contents of the 3x3 matrix used by the CSR/CSC fixtures:
    /// `[[1, 0, 2], [0, 0, 3], [4, 5, 0]]`.
    const MATRIX_3X3: [f32; 9] = [1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0];

    /// Default CPU device shared by every test.
    fn cpu_device() -> <CpuRuntime as Runtime>::Device {
        <CpuRuntime as Runtime>::Device::default()
    }

    /// Builds `MATRIX_3X3` in CSR form (row pointers, column indices, values).
    fn csr_fixture(device: &<CpuRuntime as Runtime>::Device) -> SparseTensor<CpuRuntime> {
        SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 2, 3, 5],
            &[0i64, 2, 2, 0, 1],
            &[1.0f32, 2.0, 3.0, 4.0, 5.0],
            [3, 3],
            device,
        )
        .unwrap()
    }

    /// COO entries (0,1)=5, (1,0)=3, (2,2)=7 land at their row-major offsets.
    #[test]
    fn test_to_dense_coo() {
        let dev = cpu_device();

        let coo = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 1, 2],
            &[1i64, 0, 2],
            &[5.0f32, 3.0, 7.0],
            [3, 3],
            &dev,
        )
        .unwrap();

        let out = coo.to_dense(&dev).unwrap();
        assert_eq!(out.shape(), &[3, 3]);
        assert_eq!(out.dtype(), DType::F32);

        let flat: Vec<f32> = out.to_vec();
        assert_eq!(flat, vec![0.0, 5.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 7.0]);
    }

    /// A CSR tensor densifies to the matrix its row pointers describe.
    #[test]
    fn test_to_dense_csr() {
        let dev = cpu_device();
        let csr = csr_fixture(&dev);
        assert!(csr.is_csr());

        let flat: Vec<f32> = csr.to_dense(&dev).unwrap().to_vec();
        assert_eq!(flat, MATRIX_3X3.to_vec());
    }

    /// The CSC encoding of the same matrix yields the same dense result.
    #[test]
    fn test_to_dense_csc() {
        let dev = cpu_device();

        let csc = SparseTensor::<CpuRuntime>::from_csc_slices(
            &[0i64, 2, 3, 5],
            &[0i64, 2, 2, 0, 1],
            &[1.0f32, 4.0, 5.0, 2.0, 3.0],
            [3, 3],
            &dev,
        )
        .unwrap();
        assert!(csc.is_csc());

        let flat: Vec<f32> = csc.to_dense(&dev).unwrap().to_vec();
        assert_eq!(flat, MATRIX_3X3.to_vec());
    }

    /// An empty sparse tensor becomes an all-zero dense tensor of the same shape.
    #[test]
    fn test_to_dense_empty() {
        let dev = cpu_device();

        let blank =
            SparseTensor::<CpuRuntime>::empty([3, 3], DType::F32, SparseFormat::Coo, &dev);

        let out = blank.to_dense(&dev).unwrap();
        assert_eq!(out.shape(), &[3, 3]);

        let flat: Vec<f32> = out.to_vec();
        assert_eq!(flat, vec![0.0; 9]);
    }

    /// A single stored entry in the bottom-right corner of a 2x2 matrix.
    #[test]
    fn test_to_dense_single_element() {
        let dev = cpu_device();

        let lone = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[1i64],
            &[1i64],
            &[42.0f32],
            [2, 2],
            &dev,
        )
        .unwrap();

        let flat: Vec<f32> = lone.to_dense(&dev).unwrap().to_vec();
        assert_eq!(flat, vec![0.0, 0.0, 0.0, 42.0]);
    }

    /// The dense result keeps the f64 element type of the sparse input.
    #[test]
    fn test_to_dense_f64() {
        let dev = cpu_device();

        let diag = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 1],
            &[0i64, 1],
            &[1.5f64, 2.5],
            [2, 2],
            &dev,
        )
        .unwrap();

        let out = diag.to_dense(&dev).unwrap();
        assert_eq!(out.dtype(), DType::F64);

        let flat: Vec<f64> = out.to_vec();
        assert_eq!(flat, vec![1.5, 0.0, 0.0, 2.5]);
    }

    /// dense -> sparse -> dense reproduces the original values exactly.
    #[test]
    fn test_dense_sparse_roundtrip() {
        let dev = cpu_device();
        let client = CpuClient::new(dev.clone());

        let source = MATRIX_3X3.to_vec();
        let dense_in = Tensor::<CpuRuntime>::from_slice(&source, &[3, 3], &dev);

        let as_sparse = SparseTensor::from_dense(&client, &dense_in, 1e-10).unwrap();
        let back: Vec<f32> = as_sparse.to_dense(&dev).unwrap().to_vec();

        assert_eq!(back, source);
    }

    /// CSR -> dense -> COO keeps all five non-zeros, and COO -> dense matches.
    #[test]
    fn test_csr_to_dense_to_sparse_roundtrip() {
        let dev = cpu_device();
        let client = CpuClient::new(dev.clone());

        let csr = csr_fixture(&dev);
        let dense_mid = csr.to_dense(&dev).unwrap();

        let coo = SparseTensor::from_dense(&client, &dense_mid, 1e-10).unwrap();
        assert!(coo.is_coo());
        assert_eq!(coo.nnz(), 5);

        let flat: Vec<f32> = coo.to_dense(&dev).unwrap().to_vec();
        assert_eq!(flat, MATRIX_3X3.to_vec());
    }
}