use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
use crate::collections::sparse_vector::GetGraphblasSparseVector;
use crate::context::CallGraphBlasContext;
use crate::error::SparseLinearAlgebraError;
use crate::operators::binary_operator::AccumulatorBinaryOperator;
use crate::operators::binary_operator::BinaryOperator;
use crate::operators::mask::VectorMask;
use crate::operators::options::{
GetOptionsForOperatorWithMatrixArgument, WithTransposeMatrixArgument,
};
use crate::value_type::ValueType;
use crate::graphblas_bindings::GrB_Matrix_reduce_BinaryOp;
/// Reduces the rows or columns of a sparse matrix into a sparse vector by folding the stored
/// values with a binary operator (see its `ReduceWithBinaryOperator` implementation).
///
/// The reducer carries no state: operator, argument, accumulator, mask and options are all
/// supplied per call. Because it has no fields it is automatically `Send` and `Sync`, so no
/// `unsafe impl` is required (or justified) for it.
#[derive(Debug, Clone, Default)]
pub struct BinaryOperatorReducer {}

impl BinaryOperatorReducer {
    /// Creates a new (stateless) reducer; equivalent to `BinaryOperatorReducer::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
/// Reduction of a sparse matrix into a sparse vector using a binary operator.
///
/// `EvaluationDomain` is the value type in which both `operator` and `accumulator` are
/// evaluated.
///
/// NOTE(review): the GraphBLAS C API expects the reducing operator to be associative and
/// commutative; behaviour for other operators is left to the backend — TODO confirm for the
/// GraphBLAS implementation in use.
pub trait ReduceWithBinaryOperator<EvaluationDomain: ValueType> {
    /// Reduces every row of `argument` into one element of `product`, folding the row's stored
    /// values with `operator`.
    ///
    /// With `options` not requesting a transpose, `product` should have as many elements as
    /// `argument` has rows. The result is written through `mask` and combined with any existing
    /// content of `product` via `accumulator`.
    ///
    /// The method name keeps its historical spelling for backward compatibility; see
    /// `to_column_vector` for a correctly spelled alias.
    ///
    /// # Errors
    /// Returns the error reported by the GraphBLAS context if the reduction fails.
    fn to_colunm_vector(
        &self,
        operator: &impl BinaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &impl GetOptionsForOperatorWithMatrixArgument,
    ) -> Result<(), SparseLinearAlgebraError>;

    /// Correctly spelled alias of `to_colunm_vector`; it delegates to it unchanged.
    ///
    /// # Errors
    /// Returns whatever `to_colunm_vector` returns.
    fn to_column_vector(
        &self,
        operator: &impl BinaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &impl GetOptionsForOperatorWithMatrixArgument,
    ) -> Result<(), SparseLinearAlgebraError> {
        self.to_colunm_vector(operator, argument, accumulator, product, mask, options)
    }

    /// Reduces every column of `argument` into one element of `product`, folding the column's
    /// stored values with `operator`.
    ///
    /// With `options` not requesting a transpose, `product` should have as many elements as
    /// `argument` has columns. The result is written through `mask` and combined with any
    /// existing content of `product` via `accumulator`.
    ///
    /// # Errors
    /// Returns the error reported by the GraphBLAS context if the reduction fails.
    fn to_row_vector(
        &self,
        operator: &impl BinaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &(impl GetOptionsForOperatorWithMatrixArgument + WithTransposeMatrixArgument),
    ) -> Result<(), SparseLinearAlgebraError>;
}
impl<EvaluationDomain: ValueType> ReduceWithBinaryOperator<EvaluationDomain>
    for BinaryOperatorReducer
{
    /// Runs `GrB_Matrix_reduce_BinaryOp` on `argument`, using the descriptor from `options`
    /// exactly as supplied, and stores the outcome in `product` (through `mask`, combined via
    /// `accumulator`).
    fn to_colunm_vector(
        &self,
        operator: &impl BinaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &impl GetOptionsForOperatorWithMatrixArgument,
    ) -> Result<(), SparseLinearAlgebraError> {
        let graphblas_context = product.context_ref();

        // The FFI call is deferred: it runs only when the context invokes this closure.
        // SAFETY (NOTE(review)): relies on every `GetGraphblas*` implementor handing out valid,
        // initialised GraphBLAS handles; nothing in this block checks that.
        let reduce_into_product = || unsafe {
            GrB_Matrix_reduce_BinaryOp(
                product.graphblas_vector_ptr(),
                mask.graphblas_vector_ptr(),
                accumulator.accumulator_graphblas_type(),
                operator.graphblas_type(),
                argument.graphblas_matrix_ptr(),
                options.graphblas_descriptor(),
            )
        };

        // The product's handle is passed alongside the closure — presumably so the context can
        // report on it if the call fails. TODO confirm against `CallGraphBlasContext::call`.
        graphblas_context.call(reduce_into_product, unsafe {
            product.graphblas_vector_ptr_ref()
        })?;
        Ok(())
    }

    /// Reduces columns by delegating to `to_colunm_vector` with the transpose flag of `options`
    /// inverted, so the row-wise GraphBLAS reduction runs over the transposed argument.
    fn to_row_vector(
        &self,
        operator: &impl BinaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &(impl GetOptionsForOperatorWithMatrixArgument + WithTransposeMatrixArgument),
    ) -> Result<(), SparseLinearAlgebraError> {
        let flipped_options = options.with_negated_transpose_matrix_argument();
        self.to_colunm_vector(
            operator,
            argument,
            accumulator,
            product,
            mask,
            &flipped_options,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collections::sparse_matrix::operations::FromMatrixElementList;
    use crate::collections::sparse_vector::operations::{
        FromVectorElementList, GetSparseVectorElementValue,
    };
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::operators::binary_operator::{Assignment, First, Plus};
    use crate::collections::sparse_matrix::{
        GetMatrixDimensions, MatrixElementList, Size, SparseMatrix,
    };
    use crate::collections::sparse_vector::{SparseVector, VectorElementList};
    use crate::operators::mask::SelectEntireVector;
    use crate::operators::options::OptionsForOperatorWithMatrixArgument;

    // Both tests use the same 10 x 15 u8 fixture:
    //   (1,1)=1  (1,5)=1  (2,1)=2  (4,2)=4  (5,2)=5
    // Row sums:    row 1 -> 2, row 2 -> 2, row 4 -> 4, row 5 -> 5
    // Column sums: col 1 -> 3, col 2 -> 9, col 5 -> 1

    #[test]
    fn test_binary_operator_reducer() {
        let context = Context::init_default().unwrap();
        let element_list = MatrixElementList::<u8>::from_element_vector(vec![
            (1, 1, 1).into(),
            (1, 5, 1).into(),
            (2, 1, 2).into(),
            (4, 2, 4).into(),
            (5, 2, 5).into(),
        ]);
        let matrix_size: Size = (10, 15).into();
        let matrix = SparseMatrix::<u8>::from_element_list(
            context.clone(),
            matrix_size,
            element_list,
            &First::<u8>::new(),
        )
        .unwrap();

        // Unmasked row reduction: every non-empty row yields one entry.
        let mut product_vector =
            SparseVector::<u8>::new(context.clone(), matrix_size.row_height()).unwrap();
        let reducer = BinaryOperatorReducer::new();
        reducer
            .to_colunm_vector(
                &Plus::<u8>::new(),
                &matrix,
                &Assignment::new(),
                &mut product_vector,
                &SelectEntireVector::new(context.clone()),
                &OptionsForOperatorWithMatrixArgument::new_default(),
            )
            .unwrap();
        println!("{}", product_vector);
        assert_eq!(product_vector.number_of_stored_elements().unwrap(), 4);
        assert_eq!(product_vector.element_value_or_default(1).unwrap(), 2);
        assert_eq!(product_vector.element_value_or_default(2).unwrap(), 2);
        assert_eq!(product_vector.element_value_or_default(4).unwrap(), 4);
        assert_eq!(product_vector.element_value_or_default(5).unwrap(), 5);
        assert_eq!(product_vector.element_value(9).unwrap(), None);

        // Masked row reduction: only rows 1, 2 and 4 are selected by the mask.
        let mask_element_list = VectorElementList::<u8>::from_element_vector(vec![
            (1, 1).into(),
            (2, 2).into(),
            (4, 4).into(),
        ]);
        let mask = SparseVector::<u8>::from_element_list(
            context.clone(),
            matrix_size.row_height(),
            mask_element_list,
            &First::<u8>::new(),
        )
        .unwrap();
        let mut product_vector =
            SparseVector::<u8>::new(context.clone(), matrix_size.row_height()).unwrap();
        reducer
            .to_colunm_vector(
                &Plus::<u8>::new(),
                &matrix,
                &Assignment::new(),
                &mut product_vector,
                &mask,
                &OptionsForOperatorWithMatrixArgument::new_default(),
            )
            .unwrap();
        println!("{}", matrix);
        println!("{}", product_vector);
        assert_eq!(product_vector.number_of_stored_elements().unwrap(), 3);
        assert_eq!(product_vector.element_value_or_default(1).unwrap(), 2);
        assert_eq!(product_vector.element_value_or_default(2).unwrap(), 2);
        assert_eq!(product_vector.element_value_or_default(4).unwrap(), 4);
        assert_eq!(product_vector.element_value(5).unwrap(), None);
        assert_eq!(product_vector.element_value(9).unwrap(), None);
    }

    #[test]
    fn test_binary_operator_reducer_to_row_vector() {
        let context = Context::init_default().unwrap();
        let element_list = MatrixElementList::<u8>::from_element_vector(vec![
            (1, 1, 1).into(),
            (1, 5, 1).into(),
            (2, 1, 2).into(),
            (4, 2, 4).into(),
            (5, 2, 5).into(),
        ]);
        let matrix_size: Size = (10, 15).into();
        let matrix = SparseMatrix::<u8>::from_element_list(
            context.clone(),
            matrix_size,
            element_list,
            &First::<u8>::new(),
        )
        .unwrap();

        // Column reduction: the product has one element per matrix column.
        let mut product_vector =
            SparseVector::<u8>::new(context.clone(), matrix_size.column_width()).unwrap();
        BinaryOperatorReducer::new()
            .to_row_vector(
                &Plus::<u8>::new(),
                &matrix,
                &Assignment::new(),
                &mut product_vector,
                &SelectEntireVector::new(context.clone()),
                &OptionsForOperatorWithMatrixArgument::new_default(),
            )
            .unwrap();
        assert_eq!(product_vector.number_of_stored_elements().unwrap(), 3);
        assert_eq!(product_vector.element_value_or_default(1).unwrap(), 3);
        assert_eq!(product_vector.element_value_or_default(2).unwrap(), 9);
        assert_eq!(product_vector.element_value_or_default(5).unwrap(), 1);
        assert_eq!(product_vector.element_value(0).unwrap(), None);
        assert_eq!(product_vector.element_value(14).unwrap(), None);
    }
}