Skip to main content

numr/sparse/tensor/elementwise/
mul.rs

1//! Element-wise multiplication operation for sparse tensors
2
3use crate::dtype::DType;
4use crate::error::{Error, Result};
5use crate::runtime::Runtime;
6use crate::sparse::{SparseOps, SparseTensor};
7
8impl<R: Runtime<DType = DType>> SparseTensor<R> {
9    /// Element-wise multiplication (Hadamard product): C = A .* B
10    ///
11    /// Computes the element-wise product of two sparse tensors with the same shape.
12    /// Only positions where BOTH tensors have non-zero values will be non-zero
13    /// in the result.
14    ///
15    /// # Arguments
16    ///
17    /// * `other` - Another sparse tensor with the same shape and dtype
18    ///
19    /// # Returns
20    ///
21    /// A new sparse tensor containing the element-wise product
22    ///
23    /// # Errors
24    ///
25    /// Returns error if:
26    /// - Shapes don't match
27    /// - Dtypes don't match
28    ///
29    /// # Format Handling
30    ///
31    /// - Same format: Uses native mul implementation
32    /// - Different formats: Converts to COO, multiplies, returns COO
33    ///
34    /// # Example
35    ///
36    /// ```
37    /// # use numr::prelude::*;
38    /// # #[cfg(feature = "sparse")]
39    /// # {
40    /// # use numr::sparse::SparseTensor;
41    /// # let device = CpuDevice::new();
42    /// // A:          B:          C = A .* B:
43    /// // [2, 3]      [4, 0]      [8, 0]
44    /// // [0, 5]  .*  [6, 7]  =   [0, 35]
45    /// # let a = SparseTensor::<CpuRuntime>::from_coo_slices(&[0, 0, 1], &[0, 1, 1], &[2.0f32, 3.0, 5.0], [2, 2], &device)?;
46    /// # let b = SparseTensor::<CpuRuntime>::from_coo_slices(&[0, 1], &[0, 1], &[4.0f32, 7.0], [2, 2], &device)?;
47    /// let c = a.mul(&b)?;
48    /// # }
49    /// # Ok::<(), numr::error::Error>(())
50    /// ```
51    pub fn mul(&self, other: &SparseTensor<R>) -> Result<SparseTensor<R>>
52    where
53        R::Client: SparseOps<R>,
54    {
55        // Validate shapes match
56        if self.shape() != other.shape() {
57            return Err(Error::ShapeMismatch {
58                expected: vec![self.shape()[0], self.shape()[1]],
59                got: vec![other.shape()[0], other.shape()[1]],
60            });
61        }
62
63        // Validate dtypes match
64        if self.dtype() != other.dtype() {
65            return Err(Error::DTypeMismatch {
66                lhs: self.dtype(),
67                rhs: other.dtype(),
68            });
69        }
70
71        // If same format, use native mul
72        match (self, other) {
73            (SparseTensor::Coo(a), SparseTensor::Coo(b)) => Ok(SparseTensor::Coo(a.mul(b)?)),
74            (SparseTensor::Csr(a), SparseTensor::Csr(b)) => Ok(SparseTensor::Csr(a.mul(b)?)),
75            (SparseTensor::Csc(a), SparseTensor::Csc(b)) => Ok(SparseTensor::Csc(a.mul(b)?)),
76            // Different formats: convert to COO and multiply
77            _ => {
78                let coo_a = self.to_coo()?;
79                let coo_b = other.to_coo()?;
80                let coo_a_data = coo_a.as_coo().unwrap();
81                let coo_b_data = coo_b.as_coo().unwrap();
82                Ok(SparseTensor::Coo(coo_a_data.mul(coo_b_data)?))
83            }
84        }
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::{CpuClient, CpuRuntime};
    use crate::sparse::SparseFormat;
    use crate::tensor::Tensor;

    /// Device type of the CPU runtime that every test runs on.
    type CpuDev = <CpuRuntime as Runtime>::Device;

    /// Sparse tensor specialised to the CPU runtime.
    type CpuSparse = SparseTensor<CpuRuntime>;

    /// Row-major dense form of `A .* B` for the shared 2x2 fixtures:
    ///
    /// ```text
    /// A:         B:
    /// [2, 3]     [4, 0]
    /// [0, 5]     [6, 7]
    /// ```
    const A_TIMES_B: [f32; 4] = [8.0, 0.0, 0.0, 35.0];

    /// Default CPU device.
    fn cpu() -> CpuDev {
        CpuDev::default()
    }

    /// Fixture `A` stored as COO.
    fn a_coo(device: &CpuDev) -> CpuSparse {
        CpuSparse::from_coo_slices(
            &[0i64, 0, 1],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fixture `B` stored as COO.
    fn b_coo(device: &CpuDev) -> CpuSparse {
        CpuSparse::from_coo_slices(
            &[0i64, 1, 1],
            &[0i64, 0, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fixture `A` stored as CSR (row_ptrs = [0, 2, 3]).
    fn a_csr(device: &CpuDev) -> CpuSparse {
        CpuSparse::from_csr_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Fixture `B` stored as CSR (row_ptrs = [0, 1, 3]).
    fn b_csr(device: &CpuDev) -> CpuSparse {
        CpuSparse::from_csr_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            device,
        )
        .unwrap()
    }

    /// Densifies `tensor` and reads the values back as `f32`.
    fn dense_f32(tensor: &CpuSparse, device: &CpuDev) -> Vec<f32> {
        tensor.to_dense(device).unwrap().to_vec()
    }

    #[test]
    fn test_mul_coo_coo() {
        let device = cpu();
        let product = a_coo(&device).mul(&b_coo(&device)).unwrap();

        // COO .* COO stays in COO; only (0,0) and (1,1) overlap.
        assert!(product.is_coo());
        assert_eq!(product.nnz(), 2);
        assert_eq!(dense_f32(&product, &device), A_TIMES_B);
    }

    #[test]
    fn test_mul_csr_csr() {
        let device = cpu();
        let product = a_csr(&device).mul(&b_csr(&device)).unwrap();

        // CSR .* CSR stays in CSR.
        assert!(product.is_csr());
        assert_eq!(product.nnz(), 2);
        assert_eq!(dense_f32(&product, &device), A_TIMES_B);
    }

    #[test]
    fn test_mul_csc_csc() {
        let device = cpu();

        // CSC for A: col_ptrs=[0,1,3], row_indices=[0,0,1], values=[2,3,5]
        let lhs = CpuSparse::from_csc_slices(
            &[0i64, 1, 3],
            &[0i64, 0, 1],
            &[2.0f32, 3.0, 5.0],
            [2, 2],
            &device,
        )
        .unwrap();

        // CSC for B: col_ptrs=[0,2,3], row_indices=[0,1,1], values=[4,6,7]
        let rhs = CpuSparse::from_csc_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[4.0f32, 6.0, 7.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let product = lhs.mul(&rhs).unwrap();

        // CSC .* CSC stays in CSC.
        assert!(product.is_csc());
        assert_eq!(product.nnz(), 2);
        assert_eq!(dense_f32(&product, &device), A_TIMES_B);
    }

    #[test]
    fn test_mul_mixed_formats() {
        let device = cpu();

        // A is COO, B is CSR.
        let lhs = a_coo(&device);
        let rhs = b_csr(&device);
        assert!(lhs.is_coo());
        assert!(rhs.is_csr());

        let product = lhs.mul(&rhs).unwrap();

        // Mixed formats convert to COO
        assert!(product.is_coo());
        assert_eq!(product.nnz(), 2);
        assert_eq!(dense_f32(&product, &device), A_TIMES_B);
    }

    #[test]
    fn test_mul_disjoint() {
        let device = cpu();

        // A:         B:
        // [1, 0]     [0, 2]
        // [0, 3]     [4, 0]
        // The stored positions never coincide.
        let lhs = CpuSparse::from_csr_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[1.0f32, 3.0],
            [2, 2],
            &device,
        )
        .unwrap();
        let rhs = CpuSparse::from_csr_slices(
            &[0i64, 1, 2],
            &[1i64, 0],
            &[2.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let product = lhs.mul(&rhs).unwrap();

        // Result is empty since no positions overlap
        assert_eq!(product.nnz(), 0);
        assert_eq!(dense_f32(&product, &device), [0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_mul_shape_mismatch() {
        let device = cpu();

        let wide = CpuSparse::empty([2, 3], DType::F32, SparseFormat::Csr, &device);
        let tall = CpuSparse::empty([3, 2], DType::F32, SparseFormat::Csr, &device);

        assert!(wide.mul(&tall).is_err());
    }

    #[test]
    fn test_mul_dtype_mismatch() {
        let device = cpu();

        let single = CpuSparse::empty([2, 2], DType::F32, SparseFormat::Csr, &device);
        let double = CpuSparse::empty([2, 2], DType::F64, SparseFormat::Csr, &device);

        assert!(single.mul(&double).is_err());
    }

    #[test]
    fn test_mul_from_dense() {
        let device = cpu();
        let client = CpuClient::new(device.clone());

        // Build both operands by sparsifying dense tensors.
        let dense_a = Tensor::<CpuRuntime>::from_slice(&[2.0f32, 3.0, 0.0, 5.0], &[2, 2], &device);
        let dense_b = Tensor::<CpuRuntime>::from_slice(&[4.0f32, 0.0, 6.0, 7.0], &[2, 2], &device);

        let lhs = SparseTensor::from_dense(&client, &dense_a, 1e-10).unwrap();
        let rhs = SparseTensor::from_dense(&client, &dense_b, 1e-10).unwrap();

        let product = lhs.mul(&rhs).unwrap();

        assert_eq!(dense_f32(&product, &device), A_TIMES_B);
    }

    #[test]
    fn test_mul_self() {
        let device = cpu();

        // A .* A = A^2 (element-wise)
        let diag = CpuSparse::from_csr_slices(
            &[0i64, 1, 2],
            &[0i64, 1],
            &[3.0f32, 4.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let squared = diag.mul(&diag).unwrap();

        assert_eq!(dense_f32(&squared, &device), [9.0, 0.0, 0.0, 16.0]);
    }

    #[test]
    fn test_mul_identity_sparse() {
        let device = cpu();

        // Multiplying by a sparse "all ones" matrix at same positions = same matrix
        let lhs = a_csr(&device);
        let ones = CpuSparse::from_csr_slices(
            &[0i64, 2, 3],
            &[0i64, 1, 1],
            &[1.0f32, 1.0, 1.0],
            [2, 2],
            &device,
        )
        .unwrap();

        let product = lhs.mul(&ones).unwrap();

        assert_eq!(product.nnz(), 3);
        assert_eq!(dense_f32(&product, &device), [2.0, 3.0, 0.0, 5.0]);
    }
}