1use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
2use crate::context::CallGraphBlasContext;
3use crate::error::SparseLinearAlgebraError;
4use crate::operators::binary_operator::AccumulatorBinaryOperator;
5use crate::operators::mask::MatrixMask;
6use crate::operators::options::GetOptionsForOperatorWithMatrixArguments;
7use crate::operators::{binary_operator::BinaryOperator, monoid::Monoid, semiring::Semiring};
8use crate::value_type::ValueType;
9
10use crate::graphblas_bindings::{
11 GrB_Matrix_eWiseAdd_BinaryOp, GrB_Matrix_eWiseAdd_Monoid, GrB_Matrix_eWiseAdd_Semiring,
12};
13
/// Operator that adds two sparse matrices element-wise, combining entries with a semiring.
///
/// The type carries no state; the work is done by its
/// [`ApplyElementWiseMatrixAdditionSemiring`] implementation. Because it has no fields it is
/// automatically `Send` and `Sync`, so no `unsafe impl` is required.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ElementWiseMatrixAdditionSemiringOperator {}

impl ElementWiseMatrixAdditionSemiringOperator {
    /// Creates a new operator; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
28
29pub trait ApplyElementWiseMatrixAdditionSemiring<EvaluationDomain: ValueType> {
30 fn apply(
31 &self,
32 multiplier: &impl GetGraphblasSparseMatrix,
33 operator: &impl Semiring<EvaluationDomain>,
34 multiplicant: &impl GetGraphblasSparseMatrix,
35 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
36 product: &mut impl GetGraphblasSparseMatrix,
37 mask: &impl MatrixMask,
38 options: &impl GetOptionsForOperatorWithMatrixArguments,
39 ) -> Result<(), SparseLinearAlgebraError>;
40}
41
42impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionSemiring<EvaluationDomain>
43 for ElementWiseMatrixAdditionSemiringOperator
44{
45 fn apply(
46 &self,
47 multiplier: &impl GetGraphblasSparseMatrix,
48 operator: &impl Semiring<EvaluationDomain>,
49 multiplicant: &impl GetGraphblasSparseMatrix,
50 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
51 product: &mut impl GetGraphblasSparseMatrix,
52 mask: &impl MatrixMask,
53 options: &impl GetOptionsForOperatorWithMatrixArguments,
54 ) -> Result<(), SparseLinearAlgebraError> {
55 let context = product.context_ref();
56
57 context.call(
58 || unsafe {
59 GrB_Matrix_eWiseAdd_Semiring(
60 product.graphblas_matrix_ptr(),
61 mask.graphblas_matrix_ptr(),
62 accumulator.accumulator_graphblas_type(),
63 operator.graphblas_type(),
64 multiplier.graphblas_matrix_ptr(),
65 multiplicant.graphblas_matrix_ptr(),
66 options.graphblas_descriptor(),
67 )
68 },
69 unsafe { &product.graphblas_matrix_ptr() },
70 )?;
71 Ok(())
72 }
73}
74
/// Operator that adds two sparse matrices element-wise, combining entries with a monoid.
///
/// The type carries no state; the work is done by its
/// [`ApplyElementWiseMatrixAdditionMonoidOperator`] implementation. Because it has no fields
/// it is automatically `Send` and `Sync`, so no `unsafe impl` is required.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ElementWiseMatrixAdditionMonoidOperator {}

impl ElementWiseMatrixAdditionMonoidOperator {
    /// Creates a new operator; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
89
90pub trait ApplyElementWiseMatrixAdditionMonoidOperator<EvaluationDomain: ValueType> {
91 fn apply(
92 &self,
93 multiplier: &impl GetGraphblasSparseMatrix,
94 operator: &impl Monoid<EvaluationDomain>,
95 multiplicant: &impl GetGraphblasSparseMatrix,
96 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
97 product: &mut impl GetGraphblasSparseMatrix,
98 mask: &impl MatrixMask,
99 options: &impl GetOptionsForOperatorWithMatrixArguments,
100 ) -> Result<(), SparseLinearAlgebraError>;
101}
102
103impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionMonoidOperator<EvaluationDomain>
104 for ElementWiseMatrixAdditionMonoidOperator
105{
106 fn apply(
107 &self,
108 multiplier: &impl GetGraphblasSparseMatrix,
109 operator: &impl Monoid<EvaluationDomain>,
110 multiplicant: &impl GetGraphblasSparseMatrix,
111 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
112 product: &mut impl GetGraphblasSparseMatrix,
113 mask: &impl MatrixMask,
114 options: &impl GetOptionsForOperatorWithMatrixArguments,
115 ) -> Result<(), SparseLinearAlgebraError> {
116 let context = product.context_ref();
117
118 context.call(
119 || unsafe {
120 GrB_Matrix_eWiseAdd_Monoid(
121 product.graphblas_matrix_ptr(),
122 mask.graphblas_matrix_ptr(),
123 accumulator.accumulator_graphblas_type(),
124 operator.graphblas_type(),
125 multiplier.graphblas_matrix_ptr(),
126 multiplicant.graphblas_matrix_ptr(),
127 options.graphblas_descriptor(),
128 )
129 },
130 unsafe { &product.graphblas_matrix_ptr() },
131 )?;
132
133 Ok(())
134 }
135}
136
/// Operator that adds two sparse matrices element-wise, combining entries with a binary
/// operator.
///
/// The type carries no state; the work is done by its
/// [`ApplyElementWiseMatrixAdditionBinaryOperator`] implementation. Because it has no fields
/// it is automatically `Send` and `Sync`, so no `unsafe impl` is required.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ElementWiseMatrixAdditionBinaryOperator {}

impl ElementWiseMatrixAdditionBinaryOperator {
    /// Creates a new operator; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
151
152pub trait ApplyElementWiseMatrixAdditionBinaryOperator<EvaluationDomain: ValueType> {
153 fn apply(
154 &self,
155 multiplier: &impl GetGraphblasSparseMatrix,
156 operator: &impl BinaryOperator<EvaluationDomain>,
157 multiplicant: &impl GetGraphblasSparseMatrix,
158 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
159 product: &mut impl GetGraphblasSparseMatrix,
160 mask: &impl MatrixMask,
161 options: &impl GetOptionsForOperatorWithMatrixArguments,
162 ) -> Result<(), SparseLinearAlgebraError>;
163}
164
165impl<EvaluationDomain: ValueType> ApplyElementWiseMatrixAdditionBinaryOperator<EvaluationDomain>
166 for ElementWiseMatrixAdditionBinaryOperator
167{
168 fn apply(
169 &self,
170 multiplier: &impl GetGraphblasSparseMatrix,
171 operator: &impl BinaryOperator<EvaluationDomain>,
172 multiplicant: &impl GetGraphblasSparseMatrix,
173 accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
174 product: &mut impl GetGraphblasSparseMatrix,
175 mask: &impl MatrixMask,
176 options: &impl GetOptionsForOperatorWithMatrixArguments,
177 ) -> Result<(), SparseLinearAlgebraError> {
178 let context = product.context_ref();
179
180 context.call(
181 || unsafe {
182 GrB_Matrix_eWiseAdd_BinaryOp(
183 product.graphblas_matrix_ptr(),
184 mask.graphblas_matrix_ptr(),
185 accumulator.accumulator_graphblas_type(),
186 operator.graphblas_type(),
187 multiplier.graphblas_matrix_ptr(),
188 multiplicant.graphblas_matrix_ptr(),
189 options.graphblas_descriptor(),
190 )
191 },
192 unsafe { &product.graphblas_matrix_ptr() },
193 )?;
194
195 Ok(())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use crate::collections::sparse_matrix::operations::{
        FromMatrixElementList, GetSparseMatrixElementList, GetSparseMatrixElementValue,
    };
    use crate::collections::sparse_matrix::{MatrixElementList, Size, SparseMatrix};
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::operators::binary_operator::{Assignment, First, Plus, Times};
    use crate::operators::mask::SelectEntireMatrix;
    use crate::operators::options::OptionsForOperatorWithMatrixArguments;

    /// Exercises `ElementWiseMatrixAdditionBinaryOperator` with `Times` as the binary operator.
    ///
    /// Both inputs store every position of the 2x2 matrix, so the expected results are the
    /// element-wise products. Also covers accumulation with `Plus` and writing through a mask.
    #[test]
    fn test_element_wise_multiplication() {
        let context = Context::init_default().unwrap();

        let operator = Times::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let height = 2;
        let width = 2;
        let size: Size = (height, width).into();

        // Empty inputs are expected to produce an empty result.
        let multiplier = SparseMatrix::<i32>::new(context.clone(), size).unwrap();
        let multiplicant = multiplier.clone();
        let mut product = multiplier.clone();

        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();
        let element_list = product.element_list().unwrap();

        assert_eq!(product.number_of_stored_elements().unwrap(), 0);
        assert_eq!(element_list.length(), 0);
        assert_eq!(product.element_value(1, 1).unwrap(), None);

        let multiplier_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 1).into(),
            (1, 0, 2).into(),
            (0, 1, 3).into(),
            (1, 1, 4).into(),
        ]);
        let multiplier = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplier_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        let multiplicant_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 5).into(),
            (1, 0, 6).into(),
            (0, 1, 7).into(),
            (1, 1, 8).into(),
        ]);
        let multiplicant = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplicant_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        // With `Assignment` the previous contents of `product` are replaced by the products.
        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 12);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 21);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32);

        // With `Plus` as accumulator the new products are added to the existing ones,
        // doubling every entry.
        let accumulator = Plus::<i32>::new();

        let element_wise_matrix_adder_with_accumulator =
            ElementWiseMatrixAdditionBinaryOperator::new();

        element_wise_matrix_adder_with_accumulator
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &accumulator,
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5 * 2);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 12 * 2);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 21 * 2);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32 * 2);

        // Only (0, 0) and (1, 1) carry non-zero mask values. The stored 0 at (1, 0) is
        // expected to block that position, and (0, 1) has no mask entry at all.
        let mask_element_list = MatrixElementList::<u8>::from_element_vector(vec![
            (0, 0, 3).into(),
            (1, 0, 0).into(),
            (1, 1, 1).into(),
        ]);
        let mask = SparseMatrix::<u8>::from_element_list(
            context.clone(),
            size,
            mask_element_list,
            &First::<u8>::new(),
        )
        .unwrap();

        let masked_element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let mut product = SparseMatrix::<i32>::new(context, size).unwrap();

        masked_element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &accumulator,
                &mut product,
                &mask,
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5);
        assert_eq!(product.element_value(1, 0).unwrap(), None);
        assert_eq!(product.element_value(0, 1).unwrap(), None);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 32);
    }

    /// Exercises `ElementWiseMatrixAdditionBinaryOperator` with `Plus`.
    ///
    /// Expected values are the element-wise sums of two fully populated 2x2 matrices.
    #[test]
    fn test_element_wise_addition() {
        let context = Context::init_default().unwrap();

        let operator = Plus::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let element_wise_matrix_adder = ElementWiseMatrixAdditionBinaryOperator::new();

        let height = 2;
        let width = 2;
        let size: Size = (height, width).into();

        // Empty inputs are expected to produce an empty result.
        let multiplier = SparseMatrix::<i32>::new(context.clone(), size).unwrap();
        let multiplicant = multiplier.clone();
        let mut product = multiplier.clone();

        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::<i32>::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();
        let element_list = product.element_list().unwrap();

        assert_eq!(product.number_of_stored_elements().unwrap(), 0);
        assert_eq!(element_list.length(), 0);
        assert_eq!(product.element_value(1, 1).unwrap(), None);

        let multiplier_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 1).into(),
            (1, 0, 2).into(),
            (0, 1, 3).into(),
            (1, 1, 4).into(),
        ]);
        let multiplier = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplier_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        let multiplicant_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 5).into(),
            (1, 0, 6).into(),
            (0, 1, 7).into(),
            (1, 1, 8).into(),
        ]);
        let multiplicant = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplicant_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        element_wise_matrix_adder
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 6);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 8);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 10);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 12);
    }
}