use crate::collections::sparse_matrix::GetGraphblasSparseMatrix;
use crate::collections::sparse_vector::GetGraphblasSparseVector;
use crate::context::CallGraphBlasContext;
use crate::error::SparseLinearAlgebraError;
use crate::operators::binary_operator::AccumulatorBinaryOperator;
use crate::operators::mask::{MatrixMask, VectorMask};
use crate::operators::options::{GetOperatorOptions, GetOptionsForOperatorWithMatrixArgument};
use crate::operators::unary_operator::UnaryOperator;
use crate::value_type::ValueType;
use crate::graphblas_bindings::{GrB_Matrix_apply, GrB_Vector_apply};
/// Applies a [`UnaryOperator`] element-wise to GraphBLAS sparse vectors and matrices.
///
/// The applier carries no state of its own: every call receives the operator, input,
/// accumulator, output, mask and options explicitly. Because the struct has no fields
/// it is automatically `Send` and `Sync`. No `unsafe impl` is needed, and not writing
/// one means the compiler re-checks thread-safety if fields are ever added.
#[derive(Debug, Clone, Default)]
pub struct UnaryOperatorApplier {}

impl UnaryOperatorApplier {
    /// Creates a new applier. Equivalent to [`UnaryOperatorApplier::default`].
    pub fn new() -> Self {
        Self {}
    }
}
/// Element-wise application of a [`UnaryOperator`] to a sparse vector or matrix.
///
/// The intended semantics follow the GraphBLAS `GrB_apply` family:
/// `product<mask> = accumulator(product, operator(argument))`. The operator is applied
/// to every stored element of `argument`. The `accumulator` controls how results are
/// merged into `product`, and the `mask` controls which positions of `product` may be
/// written.
///
/// `EvaluationDomain` is the value type in which both `operator` and `accumulator`
/// are evaluated.
pub trait ApplyUnaryOperator<EvaluationDomain: ValueType> {
    /// Applies `operator` to each stored element of the vector `argument` and stores
    /// the (accumulated, masked) result in `product`.
    ///
    /// * `operator` - unary operator evaluated on every stored element.
    /// * `argument` - input vector.
    /// * `accumulator` - merges the operator output into the existing `product` values.
    /// * `product` - output vector that is written in place.
    /// * `mask` - restricts which entries of `product` are written.
    /// * `options` - GraphBLAS descriptor settings for this call.
    ///
    /// # Errors
    /// Returns a [`SparseLinearAlgebraError`] when the underlying operation fails.
    fn apply_to_vector(
        &self,
        operator: &impl UnaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseVector,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &impl GetOperatorOptions,
    ) -> Result<(), SparseLinearAlgebraError>;

    /// Applies `operator` to each stored element of the matrix `argument` and stores
    /// the (accumulated, masked) result in `product`.
    ///
    /// * `operator` - unary operator evaluated on every stored element.
    /// * `argument` - input matrix.
    /// * `accumulator` - merges the operator output into the existing `product` values.
    /// * `product` - output matrix that is written in place.
    /// * `mask` - restricts which entries of `product` are written.
    /// * `options` - GraphBLAS descriptor settings for an operator that takes a
    ///   matrix argument.
    ///
    /// # Errors
    /// Returns a [`SparseLinearAlgebraError`] when the underlying operation fails.
    fn apply_to_matrix(
        &self,
        operator: &impl UnaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArgument,
    ) -> Result<(), SparseLinearAlgebraError>;
}
impl<EvaluationDomain> ApplyUnaryOperator<EvaluationDomain> for UnaryOperatorApplier
where
    EvaluationDomain: ValueType,
{
    /// Delegates to GraphBLAS `GrB_Vector_apply`, using the context owned by `argument`
    /// to execute the call and translate its status into a `Result`.
    fn apply_to_vector(
        &self,
        operator: &impl UnaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseVector,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseVector,
        mask: &impl VectorMask,
        options: &impl GetOperatorOptions,
    ) -> Result<(), SparseLinearAlgebraError> {
        // The call runs through the context that `argument` belongs to.
        let graphblas_context = argument.context_ref();
        // Handle of the output vector, handed to the context alongside the call.
        // NOTE(review): presumably used for error reporting on the output; confirm.
        let product_handle = unsafe { product.graphblas_vector_ptr() };
        // SAFETY (assumed): every handle is read from a wrapper that stays borrowed for
        // the whole call, so the wrappers are expected to keep them valid; confirm
        // against the collection/operator implementations.
        let apply_operator = || unsafe {
            GrB_Vector_apply(
                product.graphblas_vector_ptr(),
                mask.graphblas_vector_ptr(),
                accumulator.accumulator_graphblas_type(),
                operator.graphblas_type(),
                argument.graphblas_vector_ptr(),
                options.graphblas_descriptor(),
            )
        };
        graphblas_context.call(apply_operator, &product_handle)?;
        Ok(())
    }

    /// Delegates to GraphBLAS `GrB_Matrix_apply`, using the context owned by `argument`
    /// to execute the call and translate its status into a `Result`.
    fn apply_to_matrix(
        &self,
        operator: &impl UnaryOperator<EvaluationDomain>,
        argument: &impl GetGraphblasSparseMatrix,
        accumulator: &impl AccumulatorBinaryOperator<EvaluationDomain>,
        product: &mut impl GetGraphblasSparseMatrix,
        mask: &impl MatrixMask,
        options: &impl GetOptionsForOperatorWithMatrixArgument,
    ) -> Result<(), SparseLinearAlgebraError> {
        // The call runs through the context that `argument` belongs to.
        let graphblas_context = argument.context_ref();
        // Handle of the output matrix, handed to the context alongside the call.
        // NOTE(review): presumably used for error reporting on the output; confirm.
        let product_handle = unsafe { product.graphblas_matrix_ptr() };
        // SAFETY (assumed): every handle is read from a wrapper that stays borrowed for
        // the whole call, so the wrappers are expected to keep them valid; confirm
        // against the collection/operator implementations.
        let apply_operator = || unsafe {
            GrB_Matrix_apply(
                product.graphblas_matrix_ptr(),
                mask.graphblas_matrix_ptr(),
                accumulator.accumulator_graphblas_type(),
                operator.graphblas_type(),
                argument.graphblas_matrix_ptr(),
                options.graphblas_descriptor(),
            )
        };
        graphblas_context.call(apply_operator, &product_handle)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collections::sparse_matrix::operations::{
        FromMatrixElementList, GetSparseMatrixElementValue,
    };
    use crate::collections::sparse_matrix::{MatrixElementList, Size, SparseMatrix};
    use crate::collections::sparse_vector::operations::{
        FromVectorElementList, GetSparseVectorElementValue,
    };
    use crate::collections::sparse_vector::{SparseVector, VectorElementList};
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::operators::binary_operator::{Assignment, First};
    use crate::operators::mask::{SelectEntireMatrix, SelectEntireVector};
    use crate::operators::options::{OperatorOptions, OptionsForOperatorWithMatrixArgument};
    use crate::operators::unary_operator::{Identity, LogicalNegation, One};

    /// `One` maps every stored entry to 1; `Identity` then overwrites it with the
    /// original values. Both runs keep the sparsity pattern of the input matrix.
    #[test]
    fn test_matrix_unary_operator() {
        let context = Context::init_default().unwrap();
        let applier = UnaryOperatorApplier::new();
        let size: Size = (10, 15).into();

        let source = SparseMatrix::<u8>::from_element_list(
            context.clone(),
            size,
            MatrixElementList::<u8>::from_element_vector(vec![
                (1, 1, 1).into(),
                (2, 1, 2).into(),
                (4, 2, 4).into(),
                (5, 2, 5).into(),
            ]),
            &First::<u8>::new(),
        )
        .unwrap();
        let mut result = SparseMatrix::<u8>::new(context.clone(), size).unwrap();

        let assign = Assignment::<u8>::new();
        let everywhere = SelectEntireMatrix::new(context.clone());
        let defaults = OptionsForOperatorWithMatrixArgument::new_default();

        applier
            .apply_to_matrix(
                &One::<u8>::new(),
                &source,
                &assign,
                &mut result,
                &everywhere,
                &defaults,
            )
            .unwrap();
        println!("{}", result);
        assert_eq!(result.number_of_stored_elements().unwrap(), 4);
        assert_eq!(result.element_value_or_default(2, 1).unwrap(), 1);
        assert_eq!(result.element_value(9, 1).unwrap(), None);

        applier
            .apply_to_matrix(
                &Identity::<u8>::new(),
                &source,
                &assign,
                &mut result,
                &everywhere,
                &defaults,
            )
            .unwrap();
        println!("{}", source);
        println!("{}", result);
        assert_eq!(result.number_of_stored_elements().unwrap(), 4);
        assert_eq!(result.element_value_or_default(2, 1).unwrap(), 2);
        assert_eq!(result.element_value(9, 1).unwrap(), None);
    }

    /// Vector counterpart of `test_matrix_unary_operator`.
    #[test]
    fn test_vector_unary_operator() {
        let context = Context::init_default().unwrap();
        let applier = UnaryOperatorApplier::new();
        let length: usize = 10;

        let source = SparseVector::<u8>::from_element_list(
            context.clone(),
            length,
            VectorElementList::<u8>::from_element_vector(vec![
                (1, 1).into(),
                (2, 2).into(),
                (4, 4).into(),
                (5, 5).into(),
            ]),
            &First::<u8>::new(),
        )
        .unwrap();
        let mut result = SparseVector::<u8>::new(context.clone(), length).unwrap();

        let assign = Assignment::<u8>::new();
        let everywhere = SelectEntireVector::new(context.clone());
        let defaults = OperatorOptions::new_default();

        applier
            .apply_to_vector(
                &One::<u8>::new(),
                &source,
                &assign,
                &mut result,
                &everywhere,
                &defaults,
            )
            .unwrap();
        println!("{}", result);
        assert_eq!(result.number_of_stored_elements().unwrap(), 4);
        assert_eq!(result.element_value_or_default(2).unwrap(), 1);
        assert_eq!(result.element_value(9).unwrap(), None);

        applier
            .apply_to_vector(
                &Identity::<u8>::new(),
                &source,
                &assign,
                &mut result,
                &everywhere,
                &defaults,
            )
            .unwrap();
        println!("{}", source);
        println!("{}", result);
        assert_eq!(result.number_of_stored_elements().unwrap(), 4);
        assert_eq!(result.element_value_or_default(2).unwrap(), 2);
        assert_eq!(result.element_value(9).unwrap(), None);
    }

    /// Negating an empty boolean vector must not create any stored entries.
    #[test]
    fn test_vector_unary_negation_operator() {
        let context = Context::init_default().unwrap();
        let applier = UnaryOperatorApplier::new();
        let length: usize = 10;

        let empty_source = SparseVector::<bool>::new(context.clone(), length).unwrap();
        let mut result = SparseVector::<bool>::new(context.clone(), length).unwrap();
        let everywhere = SelectEntireVector::new(context);

        applier
            .apply_to_vector(
                &LogicalNegation::<bool>::new(),
                &empty_source,
                &Assignment::<bool>::new(),
                &mut result,
                &everywhere,
                &OperatorOptions::new_default(),
            )
            .unwrap();
        println!("{}", result);
        assert_eq!(result.number_of_stored_elements().unwrap(), 0);
    }
}