Skip to main content

graphblas_sparse_linear_algebra/operators/element_wise_addition/
element_wise_matrix_addition.rs

1use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
2use crate::context::CallGraphBlasContext;
3use crate::error::SparseLinearAlgebraError;
4use crate::operators::binary_operator::AccumulatorBinaryOperator;
5use crate::operators::mask::MatrixMask;
6use crate::operators::options::GetOptionsForOperatorWithMatrixArguments;
7use crate::operators::{binary_operator::BinaryOperator, monoid::Monoid, semiring::Semiring};
8use crate::value_type::ValueType;
9
10use crate::graphblas_bindings::{
11    GrB_Matrix_eWiseAdd_BinaryOp, GrB_Matrix_eWiseAdd_Monoid, GrB_Matrix_eWiseAdd_Semiring,
12};
13
14// Implemented methods do not provide mutable access to GraphBLAS operators or options.
15// Code review must consider that no mtable access is provided.
16// https://doc.rust-lang.org/nomicon/send-and-sync.html
17unsafe impl Sync for ElementWiseMatrixAdditionSemiringOperator {}
18unsafe impl Send for ElementWiseMatrixAdditionSemiringOperator {}
19
20#[derive(Debug, Clone)]
21pub struct ElementWiseMatrixAdditionSemiringOperator {}
22
23impl ElementWiseMatrixAdditionSemiringOperator {
24    pub fn new() -> Self {
25        Self {}
26    }
27}
28
29pub trait ApplyElementWiseMatrixAdditionSemiring<EvaluationDomain: ValueType> {
30    fn apply(
31        &self,
32        multiplier: &impl GetGraphblasSparseMatrix,
33        operator: &impl Semiring<EvaluationDomain>,
34        multiplicant: &impl GetGraphblasSparseMatrix,
35        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
36        product: &mut impl GetGraphblasSparseMatrix,
37        mask: &impl MatrixMask,
38        options: &impl GetOptionsForOperatorWithMatrixArguments,
39    ) -> Result<(), SparseLinearAlgebraError>;
40}
41
42impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionSemiring<EvaluationDomain>
43    for ElementWiseMatrixAdditionSemiringOperator
44{
45    fn apply(
46        &self,
47        multiplier: &impl GetGraphblasSparseMatrix,
48        operator: &impl Semiring<EvaluationDomain>,
49        multiplicant: &impl GetGraphblasSparseMatrix,
50        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
51        product: &mut impl GetGraphblasSparseMatrix,
52        mask: &impl MatrixMask,
53        options: &impl GetOptionsForOperatorWithMatrixArguments,
54    ) -> Result<(), SparseLinearAlgebraError> {
55        let context = product.context_ref();
56
57        context.call(
58            || unsafe {
59                GrB_Matrix_eWiseAdd_Semiring(
60                    product.graphblas_matrix_ptr(),
61                    mask.graphblas_matrix_ptr(),
62                    accumulator.accumulator_graphblas_type(),
63                    operator.graphblas_type(),
64                    multiplier.graphblas_matrix_ptr(),
65                    multiplicant.graphblas_matrix_ptr(),
66                    options.graphblas_descriptor(),
67                )
68            },
69            unsafe { &product.graphblas_matrix_ptr() },
70        )?;
71        Ok(())
72    }
73}
74
/// Element-wise matrix addition ("eWiseAdd") whose combining operator is taken from a
/// [`Monoid`].
///
/// The operator is stateless: it has no fields, and every input is supplied to
/// [`ApplyElementWiseMatrixAdditionMonoidOperator::apply`].
#[derive(Debug, Clone, Default)]
pub struct ElementWiseMatrixAdditionMonoidOperator {}

// Implemented methods do not provide mutable access to GraphBLAS operators or options.
// Code review must consider that no mutable access is provided.
// NOTE: the struct has no fields, so the compiler would already infer `Send` and `Sync`;
// these explicit impls must be re-reviewed if fields are ever added.
// https://doc.rust-lang.org/nomicon/send-and-sync.html
unsafe impl Sync for ElementWiseMatrixAdditionMonoidOperator {}
unsafe impl Send for ElementWiseMatrixAdditionMonoidOperator {}

impl ElementWiseMatrixAdditionMonoidOperator {
    /// Creates the (stateless) operator; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
89
90pub trait ApplyElementWiseMatrixAdditionMonoidOperator<EvaluationDomain: ValueType> {
91    fn apply(
92        &self,
93        multiplier: &impl GetGraphblasSparseMatrix,
94        operator: &impl Monoid<EvaluationDomain>,
95        multiplicant: &impl GetGraphblasSparseMatrix,
96        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
97        product: &mut impl GetGraphblasSparseMatrix,
98        mask: &impl MatrixMask,
99        options: &impl GetOptionsForOperatorWithMatrixArguments,
100    ) -> Result<(), SparseLinearAlgebraError>;
101}
102
103impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionMonoidOperator<EvaluationDomain>
104    for ElementWiseMatrixAdditionMonoidOperator
105{
106    fn apply(
107        &self,
108        multiplier: &impl GetGraphblasSparseMatrix,
109        operator: &impl Monoid<EvaluationDomain>,
110        multiplicant: &impl GetGraphblasSparseMatrix,
111        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
112        product: &mut impl GetGraphblasSparseMatrix,
113        mask: &impl MatrixMask,
114        options: &impl GetOptionsForOperatorWithMatrixArguments,
115    ) -> Result<(), SparseLinearAlgebraError> {
116        let context = product.context_ref();
117
118        context.call(
119            || unsafe {
120                GrB_Matrix_eWiseAdd_Monoid(
121                    product.graphblas_matrix_ptr(),
122                    mask.graphblas_matrix_ptr(),
123                    accumulator.accumulator_graphblas_type(),
124                    operator.graphblas_type(),
125                    multiplier.graphblas_matrix_ptr(),
126                    multiplicant.graphblas_matrix_ptr(),
127                    options.graphblas_descriptor(),
128                )
129            },
130            unsafe { &product.graphblas_matrix_ptr() },
131        )?;
132
133        Ok(())
134    }
135}
136
/// Element-wise matrix addition ("eWiseAdd") whose combining operator is a
/// [`BinaryOperator`].
///
/// The operator is stateless: it has no fields, and every input is supplied to
/// [`ApplyElementWiseMatrixAdditionBinaryOperator::apply`].
#[derive(Debug, Clone, Default)]
pub struct ElementWiseMatrixAdditionBinaryOperator {}

// Implemented methods do not provide mutable access to GraphBLAS operators or options.
// Code review must consider that no mutable access is provided.
// NOTE: the struct has no fields, so the compiler would already infer `Send` and `Sync`;
// these explicit impls must be re-reviewed if fields are ever added.
// https://doc.rust-lang.org/nomicon/send-and-sync.html
unsafe impl Sync for ElementWiseMatrixAdditionBinaryOperator {}
unsafe impl Send for ElementWiseMatrixAdditionBinaryOperator {}

impl ElementWiseMatrixAdditionBinaryOperator {
    /// Creates the (stateless) operator; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
151
152pub trait ApplyElementWiseMatrixAdditionBinaryOperator<EvaluationDomain: ValueType> {
153    fn apply(
154        &self,
155        multiplier: &impl GetGraphblasSparseMatrix,
156        operator: &impl BinaryOperator<EvaluationDomain>,
157        multiplicant: &impl GetGraphblasSparseMatrix,
158        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
159        product: &mut impl GetGraphblasSparseMatrix,
160        mask: &impl MatrixMask,
161        options: &impl GetOptionsForOperatorWithMatrixArguments,
162    ) -> Result<(), SparseLinearAlgebraError>;
163}
164
165impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionBinaryOperator<EvaluationDomain>
166    for ElementWiseMatrixAdditionBinaryOperator
167{
168    fn apply(
169        &self,
170        multiplier: &impl GetGraphblasSparseMatrix,
171        operator: &impl BinaryOperator<EvaluationDomain>,
172        multiplicant: &impl GetGraphblasSparseMatrix,
173        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
174        product: &mut impl GetGraphblasSparseMatrix,
175        mask: &impl MatrixMask,
176        options: &impl GetOptionsForOperatorWithMatrixArguments,
177    ) -> Result<(), SparseLinearAlgebraError> {
178        let context = product.context_ref();
179
180        context.call(
181            || unsafe {
182                GrB_Matrix_eWiseAdd_BinaryOp(
183                    product.graphblas_matrix_ptr(),
184                    mask.graphblas_matrix_ptr(),
185                    accumulator.accumulator_graphblas_type(),
186                    operator.graphblas_type(),
187                    multiplier.graphblas_matrix_ptr(),
188                    multiplicant.graphblas_matrix_ptr(),
189                    options.graphblas_descriptor(),
190                )
191            },
192            unsafe { &product.graphblas_matrix_ptr() },
193        )?;
194
195        Ok(())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use crate::collections::sparse_matrix::operations::{
        FromMatrixElementList, GetSparseMatrixElementList, GetSparseMatrixElementValue,
    };
    use crate::collections::sparse_matrix::{MatrixElementList, Size, SparseMatrix};
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::operators::binary_operator::{Assignment, First, Plus, Times};
    use crate::operators::mask::SelectEntireMatrix;
    use crate::operators::options::OptionsForOperatorWithMatrixArguments;

    // Element-wise *addition* driven by the `Times` operator. Both full-matrix operands store
    // every position, so every position is combined with `Times` and the expected values are
    // the element-wise products. Also exercises an accumulator and a mask.
    #[test]
    fn test_element_wise_addition_with_times_operator() {
        let context = Context::init_default().unwrap();

        let operator = Times::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let height = 2;
        let width = 2;
        let size: Size = (height, width).into();

        let multiplier = SparseMatrix::<i32>::new(context.clone(), size).unwrap();
        let multiplicant = multiplier.clone();
        let mut product = multiplier.clone();

        // Combining empty matrices yields an empty matrix.
        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();
        let element_list = product.element_list().unwrap();

        assert_eq!(product.number_of_stored_elements().unwrap(), 0);
        assert_eq!(element_list.length(), 0);
        assert_eq!(product.element_value(1, 1).unwrap(), None); // NoValue

        let multiplier_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 1).into(),
            (1, 0, 2).into(),
            (0, 1, 3).into(),
            (1, 1, 4).into(),
        ]);
        let multiplier = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplier_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        let multiplicant_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 5).into(),
            (1, 0, 6).into(),
            (0, 1, 7).into(),
            (1, 1, 8).into(),
        ]);
        let multiplicant = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplicant_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        // Full matrices: every position is stored in both operands.
        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 12);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 21);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32);

        // With `Plus` as accumulator, the new result is added to the existing product.
        let accumulator = Plus::<i32>::new();
        let matrix_adder_with_accumulator = ElementWiseMatrixAdditionBinaryOperator::new();

        matrix_adder_with_accumulator
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &accumulator,
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5 * 2);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 12 * 2);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 21 * 2);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32 * 2);

        // Mask: (0, 0) and (1, 1) hold non-zero values, (1, 0) stores 0 and (0, 1) is absent,
        // so only (0, 0) and (1, 1) are written. The fresh output starts empty, so
        // accumulating with `Plus` leaves just the computed values.
        let mask_element_list = MatrixElementList::<u8>::from_element_vector(vec![
            (0, 0, 3).into(),
            (1, 0, 0).into(),
            (1, 1, 1).into(),
        ]);
        let mask = SparseMatrix::<u8>::from_element_list(
            context.clone(),
            size,
            mask_element_list,
            &First::<u8>::new(),
        )
        .unwrap();

        let matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let mut product = SparseMatrix::<i32>::new(context, size).unwrap();

        matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &accumulator,
                &mut product,
                &mask,
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5);
        assert_eq!(product.element_value(1, 0).unwrap(), None);
        assert_eq!(product.element_value(0, 1).unwrap(), None);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32);
    }

    #[test]
    fn test_element_wise_addition() {
        let context = Context::init_default().unwrap();

        let operator = Plus::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let height = 2;
        let width = 2;
        let size: Size = (height, width).into();

        let multiplier = SparseMatrix::<i32>::new(context.clone(), size).unwrap();
        let multiplicant = multiplier.clone();
        let mut product = multiplier.clone();

        // Test addition of empty matrices
        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();
        let element_list = product.element_list().unwrap();

        assert_eq!(product.number_of_stored_elements().unwrap(), 0);
        assert_eq!(element_list.length(), 0);
        assert_eq!(product.element_value(1, 1).unwrap(), None); // NoValue

        let multiplier_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 1).into(),
            (1, 0, 2).into(),
            (0, 1, 3).into(),
            (1, 1, 4).into(),
        ]);
        let multiplier = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplier_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        let multiplicant_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 5).into(),
            (1, 0, 6).into(),
            (0, 1, 7).into(),
            (1, 1, 8).into(),
        ]);
        let multiplicant = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplicant_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        // Test addition of full matrices
        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 6);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 8);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 10);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 12);
    }

    // eWiseAdd takes the *union* of the operands' sparsity patterns: entries stored in only
    // one operand are copied, and only positions stored in both are combined. The full-matrix
    // tests above cannot tell union from intersection, so this test uses partially
    // overlapping operands.
    #[test]
    fn test_element_wise_addition_takes_union_of_sparsity_patterns() {
        let context = Context::init_default().unwrap();

        let operator = Plus::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let size: Size = (2, 2).into();

        let left = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            MatrixElementList::<i32>::from_element_vector(vec![
                (0, 0, 1).into(),
                (1, 0, 2).into(),
            ]),
            &First::<i32>::new(),
        )
        .unwrap();
        let right = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            MatrixElementList::<i32>::from_element_vector(vec![
                (1, 0, 6).into(),
                (1, 1, 7).into(),
            ]),
            &First::<i32>::new(),
        )
        .unwrap();
        let mut product = SparseMatrix::<i32>::new(context.clone(), size).unwrap();

        element_wise_matrix_adder
            .apply(
                &left,
                &operator,
                &right,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.number_of_stored_elements().unwrap(), 3);
        assert_eq!(product.element_value(0, 0).unwrap(), Some(1)); // left only
        assert_eq!(product.element_value(1, 0).unwrap(), Some(8)); // both: 2 + 6
        assert_eq!(product.element_value(1, 1).unwrap(), Some(7)); // right only
        assert_eq!(product.element_value(0, 1).unwrap(), None); // neither
    }
}