1use crate::dtype::DType;
4use crate::error::{Error, Result};
5use crate::runtime::Runtime;
6use crate::sparse::{SparseOps, SparseTensor};
7
8impl<R: Runtime<DType = DType>> SparseTensor<R> {
9 pub fn mul(&self, other: &SparseTensor<R>) -> Result<SparseTensor<R>>
52 where
53 R::Client: SparseOps<R>,
54 {
55 if self.shape() != other.shape() {
57 return Err(Error::ShapeMismatch {
58 expected: vec![self.shape()[0], self.shape()[1]],
59 got: vec![other.shape()[0], other.shape()[1]],
60 });
61 }
62
63 if self.dtype() != other.dtype() {
65 return Err(Error::DTypeMismatch {
66 lhs: self.dtype(),
67 rhs: other.dtype(),
68 });
69 }
70
71 match (self, other) {
73 (SparseTensor::Coo(a), SparseTensor::Coo(b)) => Ok(SparseTensor::Coo(a.mul(b)?)),
74 (SparseTensor::Csr(a), SparseTensor::Csr(b)) => Ok(SparseTensor::Csr(a.mul(b)?)),
75 (SparseTensor::Csc(a), SparseTensor::Csc(b)) => Ok(SparseTensor::Csc(a.mul(b)?)),
76 _ => {
78 let coo_a = self.to_coo()?;
79 let coo_b = other.to_coo()?;
80 let coo_a_data = coo_a.as_coo().unwrap();
81 let coo_b_data = coo_b.as_coo().unwrap();
82 Ok(SparseTensor::Coo(coo_a_data.mul(coo_b_data)?))
83 }
84 }
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::error::Error;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::{CpuClient, CpuRuntime};
    use crate::sparse::SparseFormat;
    use crate::tensor::Tensor;

    /// Densifies `t` on `device` and returns its values in row-major order.
    fn dense_values(
        t: &SparseTensor<CpuRuntime>,
        device: &<CpuRuntime as Runtime>::Device,
    ) -> Vec<f32> {
        t.to_dense(device).unwrap().to_vec()
    }

    #[test]
    fn test_mul_coo_coo() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // a = [[2, 3], [0, 5]]
        let a = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 0, 1],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        // b = [[4, 0], [6, 7]]
        let b = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 1, 1],
            &[0i64, 0, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert!(c.is_coo());
        assert_eq!(c.nnz(), 2);
        assert_eq!(dense_values(&c, &device), vec![8.0, 0.0, 0.0, 35.0]);
    }

    #[test]
    fn test_mul_csr_csr() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // a = [[2, 3], [0, 5]]
        let a = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        // b = [[4, 0], [6, 7]]
        let b = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert!(c.is_csr());
        assert_eq!(c.nnz(), 2);
        assert_eq!(dense_values(&c, &device), vec![8.0, 0.0, 0.0, 35.0]);
    }

    #[test]
    fn test_mul_csc_csc() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // a = [[2, 3], [0, 5]]
        let a = SparseTensor::<CpuRuntime>::from_csc_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        // b = [[4, 0], [6, 7]]
        let b = SparseTensor::<CpuRuntime>::from_csc_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert!(c.is_csc());
        assert_eq!(c.nnz(), 2);
        assert_eq!(dense_values(&c, &device), vec![8.0, 0.0, 0.0, 35.0]);
    }

    #[test]
    fn test_mul_mixed_formats() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let a = SparseTensor::<CpuRuntime>::from_coo_slices(
            &[0i64, 0, 1],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let b = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        assert!(a.is_coo());
        assert!(b.is_csr());

        let c = a.mul(&b).unwrap();

        // Mixed-format inputs fall back to COO.
        assert!(c.is_coo());
        assert_eq!(c.nnz(), 2);
        assert_eq!(dense_values(&c, &device), vec![8.0, 0.0, 0.0, 35.0]);
    }

    #[test]
    fn test_mul_disjoint() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // a stores only the diagonal, b only the anti-diagonal: no overlap.
        let a = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let b = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 1, 2],
            &[1i64, 0],
            &[2.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&b).unwrap();

        assert_eq!(c.nnz(), 0);
        assert_eq!(dense_values(&c, &device), vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mul_shape_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let a = SparseTensor::<CpuRuntime>::empty([2, 3], DType::F32, SparseFormat::Csr, &device);
        let b = SparseTensor::<CpuRuntime>::empty([3, 2], DType::F32, SparseFormat::Csr, &device);

        // Pin the variant so an unrelated failure cannot satisfy the test.
        let result = a.mul(&b);
        assert!(matches!(result, Err(Error::ShapeMismatch { .. })));
    }

    #[test]
    fn test_mul_dtype_mismatch() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let a = SparseTensor::<CpuRuntime>::empty([2, 2], DType::F32, SparseFormat::Csr, &device);
        let b = SparseTensor::<CpuRuntime>::empty([2, 2], DType::F64, SparseFormat::Csr, &device);

        // Pin the variant so an unrelated failure cannot satisfy the test.
        let result = a.mul(&b);
        assert!(matches!(result, Err(Error::DTypeMismatch { .. })));
    }

    #[test]
    fn test_mul_from_dense() {
        let device = <CpuRuntime as Runtime>::Device::default();
        let client = CpuClient::new(device.clone());

        let dense_a = Tensor::<CpuRuntime>::from_slice(&[2.0f32, 3.0, 0.0, 5.0], &[2, 2], &device);
        let dense_b = Tensor::<CpuRuntime>::from_slice(&[4.0f32, 0.0, 6.0, 7.0], &[2, 2], &device);

        let a = SparseTensor::from_dense(&client, &dense_a, 1e-10).unwrap();
        let b = SparseTensor::from_dense(&client, &dense_b, 1e-10).unwrap();

        let c = a.mul(&b).unwrap();

        assert_eq!(dense_values(&c, &device), vec![8.0, 0.0, 0.0, 35.0]);
    }

    #[test]
    fn test_mul_self() {
        let device = <CpuRuntime as Runtime>::Device::default();

        // a = diag(3, 4); a ⊙ a = diag(9, 16)
        let a = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[3.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&a).unwrap();

        assert_eq!(dense_values(&c, &device), vec![9.0, 0.0, 0.0, 16.0]);
    }

    #[test]
    fn test_mul_identity_sparse() {
        let device = <CpuRuntime as Runtime>::Device::default();

        let a = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        // Ones on exactly a's sparsity pattern act as the element-wise identity
        // for a's stored entries.
        let ones = SparseTensor::<CpuRuntime>::from_csr_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[1.0f32, 1.0, 1.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let c = a.mul(&ones).unwrap();

        assert_eq!(c.nnz(), 3);
        assert_eq!(dense_values(&c, &device), vec![2.0, 3.0, 0.0, 5.0]);
    }
}