use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
use crate::context::CallGraphBlasContext;
use crate::error::SparseLinearAlgebraError;
use crate::operators::{binary_operator::BinaryOperator, monoid::Monoid, semiring::Semiring};
use crate::value_type::ValueType;
use crate::graphblas_bindings::{
GrB_Matrix_kronecker_BinaryOp, GrB_Matrix_kronecker_Monoid, GrB_Matrix_kronecker_Semiring,
};
use super::binary_operator::AccumulatorBinaryOperator;
use super::mask::MatrixMask;
use super::options::GetOptionsForOperatorWithMatrixArguments;
/// Stateless operator whose `apply` computes a Kronecker product, combining element pairs with a
/// [`Semiring`] (forwarded to `GrB_Matrix_kronecker_Semiring`).
///
/// The struct has no fields, so the compiler derives `Send` and `Sync` for it automatically;
/// no `unsafe impl` is needed.
#[derive(Debug, Clone, Default)]
pub struct SemiringKroneckerProductOperator {}

impl SemiringKroneckerProductOperator {
    /// Creates a new (stateless) semiring Kronecker product operator.
    pub fn new() -> Self {
        Self {}
    }
}
/// Kronecker product `product = multiplier ⊗ multiplicant` in which element pairs are combined
/// through a [`Semiring`].
pub trait SemiringKroneckerProduct<EvaluationDomain>
where
    EvaluationDomain: ValueType,
{
    /// Computes the Kronecker product of `multiplier` and `multiplicant` and writes it into
    /// `product`.
    ///
    /// `accumulator`, `mask` and `options` govern how the result is merged into `product`.
    /// Presumably this follows GraphBLAS `GrB_kronecker` semantics; confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseLinearAlgebraError`] if the operation fails.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl Semiring<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError>;
}
impl<EvaluationDomain> SemiringKroneckerProduct<EvaluationDomain>
    for SemiringKroneckerProductOperator
where
    EvaluationDomain: ValueType,
{
    /// Forwards all arguments to `GrB_Matrix_kronecker_Semiring`, with `multiplier` as the left
    /// operand and `multiplicant` as the right operand, writing into `product`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the context's `call` wrapper.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl Semiring<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError> {
        let graphblas_context = product.context_ref();

        // SAFETY: the closure only borrows the wrappers passed in by reference, so it cannot
        // outlive them. This assumes each wrapper holds a valid, initialised GraphBLAS handle.
        let kronecker_product = || unsafe {
            GrB_Matrix_kronecker_Semiring(
                product.graphblas_matrix(),
                mask.graphblas_matrix(),
                accumulator.accumulator_graphblas_type(),
                multiplication_operator.graphblas_type(),
                multiplier.graphblas_matrix(),
                multiplicant.graphblas_matrix(),
                options.graphblas_descriptor(),
            )
        };

        // SAFETY: the handle reference is handed only to `call` below while `product` is still
        // borrowed. NOTE(review): presumably used for error reporting; confirm.
        let product_matrix = unsafe { product.graphblas_matrix_ref() };

        graphblas_context.call(kronecker_product, product_matrix)?;
        Ok(())
    }
}
/// Stateless operator whose `apply` computes a Kronecker product, combining element pairs with a
/// [`Monoid`] (forwarded to `GrB_Matrix_kronecker_Monoid`).
#[derive(Debug, Clone, Default)]
pub struct MonoidKroneckerProductOperator {}

impl MonoidKroneckerProductOperator {
    /// Creates a new (stateless) monoid Kronecker product operator.
    pub fn new() -> Self {
        Self {}
    }
}
/// Kronecker product `product = multiplier ⊗ multiplicant` in which element pairs are combined
/// through a [`Monoid`].
pub trait MonoidKroneckerProduct<EvaluationDomain>
where
    EvaluationDomain: ValueType,
{
    /// Computes the Kronecker product of `multiplier` and `multiplicant` and writes it into
    /// `product`.
    ///
    /// `accumulator`, `mask` and `options` govern how the result is merged into `product`.
    /// Presumably this follows GraphBLAS `GrB_kronecker` semantics; confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseLinearAlgebraError`] if the operation fails.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl Monoid<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError>;
}
impl<EvaluationDomain> MonoidKroneckerProduct<EvaluationDomain>
    for MonoidKroneckerProductOperator
where
    EvaluationDomain: ValueType,
{
    /// Forwards all arguments to `GrB_Matrix_kronecker_Monoid`, with `multiplier` as the left
    /// operand and `multiplicant` as the right operand, writing into `product`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the context's `call` wrapper.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl Monoid<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError> {
        let graphblas_context = product.context_ref();

        // SAFETY: the closure only borrows the wrappers passed in by reference, so it cannot
        // outlive them. This assumes each wrapper holds a valid, initialised GraphBLAS handle.
        let kronecker_product = || unsafe {
            GrB_Matrix_kronecker_Monoid(
                product.graphblas_matrix(),
                mask.graphblas_matrix(),
                accumulator.accumulator_graphblas_type(),
                multiplication_operator.graphblas_type(),
                multiplier.graphblas_matrix(),
                multiplicant.graphblas_matrix(),
                options.graphblas_descriptor(),
            )
        };

        // SAFETY: the handle reference is handed only to `call` below while `product` is still
        // borrowed. NOTE(review): presumably used for error reporting; confirm.
        let product_matrix = unsafe { product.graphblas_matrix_ref() };

        graphblas_context.call(kronecker_product, product_matrix)?;
        Ok(())
    }
}
/// Stateless operator whose `apply` computes a Kronecker product, combining element pairs with a
/// [`BinaryOperator`] (forwarded to `GrB_Matrix_kronecker_BinaryOp`).
#[derive(Debug, Clone, Default)]
pub struct BinaryOperatorKroneckerProductOperator {}

impl BinaryOperatorKroneckerProductOperator {
    /// Creates a new (stateless) binary-operator Kronecker product operator.
    pub fn new() -> Self {
        Self {}
    }
}
/// Kronecker product `product = multiplier ⊗ multiplicant` in which element pairs are combined
/// through a [`BinaryOperator`].
pub trait BinaryOperatorKroneckerProduct<EvaluationDomain>
where
    EvaluationDomain: ValueType,
{
    /// Computes the Kronecker product of `multiplier` and `multiplicant` and writes it into
    /// `product`.
    ///
    /// `accumulator`, `mask` and `options` govern how the result is merged into `product`.
    /// Presumably this follows GraphBLAS `GrB_kronecker` semantics; confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`SparseLinearAlgebraError`] if the operation fails.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl BinaryOperator<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError>;
}
impl<EvaluationDomain> BinaryOperatorKroneckerProduct<EvaluationDomain>
    for BinaryOperatorKroneckerProductOperator
where
    EvaluationDomain: ValueType,
{
    /// Forwards all arguments to `GrB_Matrix_kronecker_BinaryOp`, with `multiplier` as the left
    /// operand and `multiplicant` as the right operand, writing into `product`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the context's `call` wrapper.
    fn apply(
        &self,
        multiplier: &impl GetGraphblasSparseMatrix,
        multiplication_operator: &impl BinaryOperator<EvaluationDomain>,
        multiplicant: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArguments,
    ) -> Result<(), SparseLinearAlgebraError> {
        let graphblas_context = product.context_ref();

        // SAFETY: the closure only borrows the wrappers passed in by reference, so it cannot
        // outlive them. This assumes each wrapper holds a valid, initialised GraphBLAS handle.
        let kronecker_product = || unsafe {
            GrB_Matrix_kronecker_BinaryOp(
                product.graphblas_matrix(),
                mask.graphblas_matrix(),
                accumulator.accumulator_graphblas_type(),
                multiplication_operator.graphblas_type(),
                multiplier.graphblas_matrix(),
                multiplicant.graphblas_matrix(),
                options.graphblas_descriptor(),
            )
        };

        // SAFETY: the handle reference is handed only to `call` below while `product` is still
        // borrowed. NOTE(review): presumably used for error reporting; confirm.
        let product_matrix = unsafe { product.graphblas_matrix_ref() };

        graphblas_context.call(kronecker_product, product_matrix)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collections::sparse_matrix::operations::{
        FromMatrixElementList, GetSparseMatrixElementList, GetSparseMatrixElementValue,
    };
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::operators::binary_operator::{Assignment, First, Times};
    use crate::collections::sparse_matrix::{MatrixElementList, Size, SparseMatrix};
    use crate::operators::mask::SelectEntireMatrix;
    use crate::operators::options::OptionsForOperatorWithMatrixArguments;

    /// Kronecker product of two 2x2 matrices using the `Times` binary operator.
    ///
    /// For `C = A ⊗ B` with 2x2 operands, `C[2 * i + k, 2 * j + l] = A[i, j] * B[k, l]`;
    /// the expected values below follow that formula.
    #[test]
    fn test_binary_operator_kronecker_product() {
        let context = Context::init_default().unwrap();
        let operator = Times::<i32>::new();
        let options = OptionsForOperatorWithMatrixArguments::new_default();
        let kronecker_product_operator = BinaryOperatorKroneckerProductOperator::new();
        let height = 2;
        let width = 2;
        let size: Size = (height, width).into();

        // Empty operands must produce an empty product.
        let multiplier = SparseMatrix::<i32>::new(context.clone(), size).unwrap();
        let multiplicant = multiplier.clone();
        let mut product = SparseMatrix::<i32>::new(context.clone(), (4, 4).into()).unwrap();
        kronecker_product_operator
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();
        let element_list = product.element_list().unwrap();
        assert_eq!(product.number_of_stored_elements().unwrap(), 0);
        assert_eq!(element_list.length(), 0);
        assert_eq!(product.element_value(1, 1).unwrap(), None);

        // A = [[1, 3], [2, 4]] (row, column, value triples below).
        let multiplier_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 1).into(),
            (1, 0, 2).into(),
            (0, 1, 3).into(),
            (1, 1, 4).into(),
        ]);
        let multiplier = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplier_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        // B = [[5, 7], [6, 8]].
        let multiplicant_element_list = MatrixElementList::<i32>::from_element_vector(vec![
            (0, 0, 5).into(),
            (1, 0, 6).into(),
            (0, 1, 7).into(),
            (1, 1, 8).into(),
        ]);
        let multiplicant = SparseMatrix::<i32>::from_element_list(
            context.clone(),
            size,
            multiplicant_element_list,
            &First::<i32>::new(),
        )
        .unwrap();

        kronecker_product_operator
            .apply(
                &multiplier,
                &operator,
                &multiplicant,
                &Assignment::new(),
                &mut product,
                &SelectEntireMatrix::new(context.clone()),
                &options,
            )
            .unwrap();

        // Both operands are fully stored, so every one of the 4x4 product entries is stored.
        assert_eq!(product.number_of_stored_elements().unwrap(), 16);

        // Top-left block: A[0, 0] * B.
        assert_eq!(product.element_value_or_default(0, 0).unwrap(), 5);
        assert_eq!(product.element_value_or_default(1, 0).unwrap(), 6);
        assert_eq!(product.element_value_or_default(0, 1).unwrap(), 7);
        assert_eq!(product.element_value_or_default(1, 1).unwrap(), 8);
        // Bottom-left block: A[1, 0] * B.
        assert_eq!(product.element_value_or_default(2, 0).unwrap(), 10);
        assert_eq!(product.element_value_or_default(3, 0).unwrap(), 12);
        assert_eq!(product.element_value_or_default(2, 1).unwrap(), 14);
        assert_eq!(product.element_value_or_default(3, 1).unwrap(), 16);
        // Top-right block: A[0, 1] * B.
        assert_eq!(product.element_value_or_default(0, 2).unwrap(), 15);
        assert_eq!(product.element_value_or_default(1, 2).unwrap(), 18);
        assert_eq!(product.element_value_or_default(0, 3).unwrap(), 21);
        assert_eq!(product.element_value_or_default(1, 3).unwrap(), 24);
        // Bottom-right block: A[1, 1] * B.
        assert_eq!(product.element_value_or_default(2, 2).unwrap(), 20);
        assert_eq!(product.element_value_or_default(3, 2).unwrap(), 24);
        assert_eq!(product.element_value_or_default(2, 3).unwrap(), 28);
        assert_eq!(product.element_value_or_default(3, 3).unwrap(), 32);
    }
}