// graphblas_sparse_linear_algebra/operators/subinsert/insert_vector_into_sub_vector.rs

1use crate::collections::sparse_vector::operations::GetSparseVectorLength;
2use crate::collections::sparse_vector::{GetGraphblasSparseVector, SparseVector};
3use crate::context::{CallGraphBlasContext, GetContext};
4use crate::error::SparseLinearAlgebraError;
5use crate::graphblas_bindings::GxB_Vector_subassign;
6use crate::index::{ElementIndexSelector, ElementIndexSelectorGraphblasType, IndexConversion};
7use crate::operators::binary_operator::AccumulatorBinaryOperator;
8use crate::operators::mask::VectorMask;
9use crate::operators::options::GetOperatorOptions;
10
11use crate::value_type::ValueType;
12
// TODO: explicitly define how duplicates are handled
14
// Implemented methods do not provide mutable access to GraphBLAS operators or options.
// Code review must consider that no mutable access is provided; the `Send`/`Sync`
// implementations below rely on it.
// https://doc.rust-lang.org/nomicon/send-and-sync.html
// SAFETY: the operator has no fields and therefore no state that could be accessed
// concurrently from several threads.
unsafe impl Send for InsertVectorIntoSubVectorOperator {}
// SAFETY: same reasoning as for `Send`; shared references expose no mutable state.
unsafe impl Sync for InsertVectorIntoSubVectorOperator {}

/// Operator that writes a sparse vector into a selected sub-vector of another sparse vector,
/// i.e. `w(I)<mask> = accumulator(w(I), u)` (GraphBLAS `GxB_Vector_subassign`).
#[derive(Debug, Clone, Default)]
pub struct InsertVectorIntoSubVectorOperator {}

impl InsertVectorIntoSubVectorOperator {
    /// Creates a new operator. Equivalent to `InsertVectorIntoSubVectorOperator::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
29
30pub trait InsertVectorIntoSubVector<VectorToInsertInto>
31where
32    VectorToInsertInto: ValueType,
33{
34    /// mask and replace option apply to entire matrix_to_insert_to
35    fn apply(
36        &self,
37        vector_to_insert_into: &mut SparseVector<VectorToInsertInto>,
38        indices_to_insert_into: &ElementIndexSelector,
39        vector_to_insert: impl GetGraphblasSparseVector,
40        accumulator: &impl AccumulatorBinaryOperator<VectorToInsertInto>,
41        mask_for_vector_to_insert_into: &impl VectorMask,
42        options: &impl GetOperatorOptions,
43    ) -> Result<(), SparseLinearAlgebraError>;
44}
45
46impl<VectorToInsertInto: ValueType> InsertVectorIntoSubVector<VectorToInsertInto>
47    for InsertVectorIntoSubVectorOperator
48{
49    /// mask and replace option apply to entire matrix_to_insert_to
50    fn apply(
51        &self,
52        vector_to_insert_into: &mut SparseVector<VectorToInsertInto>,
53        indices_to_insert_into: &ElementIndexSelector,
54        vector_to_insert: impl GetGraphblasSparseVector,
55        accumulator: &impl AccumulatorBinaryOperator<VectorToInsertInto>,
56        mask_for_vector_to_insert_into: &impl VectorMask,
57        options: &impl GetOperatorOptions,
58    ) -> Result<(), SparseLinearAlgebraError> {
59        let context = vector_to_insert_into.context_ref();
60
61        let number_of_indices_to_insert_into = indices_to_insert_into
62            .number_of_selected_elements(vector_to_insert_into.length()?)?
63            .to_graphblas_index()?;
64
65        let indices_to_insert_into = indices_to_insert_into.to_graphblas_type()?;
66
67        match indices_to_insert_into {
68            ElementIndexSelectorGraphblasType::Index(index) => {
69                context.call(
70                    || unsafe {
71                        GxB_Vector_subassign(
72                            GetGraphblasSparseVector::graphblas_vector(vector_to_insert_into),
73                            mask_for_vector_to_insert_into.graphblas_vector(),
74                            accumulator.accumulator_graphblas_type(),
75                            vector_to_insert.graphblas_vector(),
76                            index.as_ptr(),
77                            number_of_indices_to_insert_into,
78                            options.graphblas_descriptor(),
79                        )
80                    },
81                    unsafe { vector_to_insert_into.graphblas_vector_ref() },
82                )?;
83            }
84
85            ElementIndexSelectorGraphblasType::All(index) => {
86                context.call(
87                    || unsafe {
88                        GxB_Vector_subassign(
89                            GetGraphblasSparseVector::graphblas_vector(vector_to_insert_into),
90                            mask_for_vector_to_insert_into.graphblas_vector(),
91                            accumulator.accumulator_graphblas_type(),
92                            vector_to_insert.graphblas_vector(),
93                            index,
94                            number_of_indices_to_insert_into,
95                            options.graphblas_descriptor(),
96                        )
97                    },
98                    unsafe { vector_to_insert_into.graphblas_vector_ref() },
99                )?;
100            }
101        }
102
103        Ok(())
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    use crate::collections::sparse_vector::operations::{
        FromVectorElementList, GetSparseVectorElementValue,
    };
    use crate::collections::sparse_vector::VectorElementList;
    use crate::collections::Collection;
    use crate::context::Context;
    use crate::index::ElementIndex;
    use crate::operators::binary_operator::{Assignment, First};
    use crate::operators::mask::SelectEntireVector;
    use crate::operators::options::OperatorOptions;

    #[test]
    fn test_insert_vector_into_vector() {
        let context = Context::init_default().unwrap();

        let element_list = VectorElementList::<u8>::from_element_vector(vec![
            (1, 1).into(),
            (2, 2).into(),
            (4, 10).into(),
            (5, 12).into(),
        ]);

        let vector_length: usize = 10;
        let mut vector = SparseVector::<u8>::from_element_list(
            context.clone(),
            vector_length,
            element_list.clone(),
            &First::<u8>::new(),
        )
        .unwrap();

        let element_list_to_insert = VectorElementList::<u8>::from_element_vector(vec![
            (1, 2).into(),
            (2, 3).into(),
            (4, 11).into(),
        ]);

        let vector_to_insert_length: usize = 5;
        let vector_to_insert = SparseVector::<u8>::from_element_list(
            context.clone(),
            vector_to_insert_length,
            element_list_to_insert,
            &First::<u8>::new(),
        )
        .unwrap();

        // Sub-assign semantics: the mask is indexed relative to the selected sub-vector, so it
        // has the sub-vector's length rather than the length of `vector`.
        let mask_element_list = VectorElementList::<bool>::from_element_vector(vec![
            (2, true).into(),
            (4, true).into(),
        ]);
        let mask = SparseVector::<bool>::from_element_list(
            context.clone(),
            vector_to_insert_length,
            mask_element_list,
            &First::<bool>::new(),
        )
        .unwrap();

        let indices_to_insert: Vec<ElementIndex> = (0..5).collect();
        let indices_to_insert = ElementIndexSelector::Index(&indices_to_insert);

        let insert_operator = InsertVectorIntoSubVectorOperator::new();

        // Without a mask, every selected position (0..5) takes the inserted vector's value;
        // position 5 lies outside the selection and keeps its original value.
        insert_operator
            .apply(
                &mut vector,
                &indices_to_insert,
                vector_to_insert.clone(),
                &Assignment::new(),
                &SelectEntireVector::new(context.clone()),
                &OperatorOptions::new_default(),
            )
            .unwrap();

        assert_eq!(vector.number_of_stored_elements().unwrap(), 4);
        assert_eq!(vector.element_value(0).unwrap(), None);
        assert_eq!(vector.element_value_or_default(1).unwrap(), 2);
        assert_eq!(vector.element_value_or_default(2).unwrap(), 3);
        assert_eq!(vector.element_value_or_default(4).unwrap(), 11);
        assert_eq!(vector.element_value_or_default(5).unwrap(), 12);

        let mut vector = SparseVector::<u8>::from_element_list(
            context.clone(),
            vector_length,
            element_list,
            &First::<u8>::new(),
        )
        .unwrap();

        // With the mask, only selected positions 2 and 4 are written; position 1 is masked out
        // and keeps its original value.
        insert_operator
            .apply(
                &mut vector,
                &indices_to_insert,
                vector_to_insert,
                &Assignment::new(),
                &mask,
                &OperatorOptions::new_default(),
            )
            .unwrap();

        assert_eq!(vector.number_of_stored_elements().unwrap(), 4);
        assert_eq!(vector.element_value(0).unwrap(), None);
        assert_eq!(vector.element_value_or_default(2).unwrap(), 3);
        assert_eq!(vector.element_value_or_default(4).unwrap(), 11);
        assert_eq!(vector.element_value_or_default(5).unwrap(), 12);
        assert_eq!(vector.element_value_or_default(1).unwrap(), 1);
    }
}
224}