Skip to main content

graphblas_sparse_linear_algebra/collections/sparse_scalar/
sparse_scalar.rs

1use std::marker::PhantomData;
2use std::mem::MaybeUninit;
3use std::sync::Arc;
4
5use suitesparse_graphblas_sys::{GrB_Info, GrB_Type};
6
7use crate::collections::collection::Collection;
8use crate::collections::sparse_scalar::operations::SetScalarValue;
9use crate::context::{CallGraphBlasContext, Context, GetContext};
10use crate::error::{
11    GraphblasErrorType, LogicErrorType, SparseLinearAlgebraError, SparseLinearAlgebraErrorType,
12};
13use crate::graphblas_bindings::{
14    GrB_Index, GrB_Scalar, GrB_Scalar_clear, GrB_Scalar_dup, GrB_Scalar_free, GrB_Scalar_new,
15    GrB_Scalar_nvals,
16};
17use crate::index::{ElementCount, ElementIndex, IndexConversion};
18use crate::value_type::utilities_to_implement_traits_for_all_value_types::{
19    implement_1_type_macro_for_all_value_types_and_typed_graphblas_function_with_implementation_type,
20    implement_macro_for_all_value_types,
21};
22use crate::value_type::ConvertScalar;
23use crate::value_type::ValueType;
24
25use crate::collections::sparse_scalar::operations::GetScalarValue;
26
27#[derive(Debug)]
28pub struct SparseScalar<T: ValueType> {
29    context: Arc<Context>,
30    scalar: GrB_Scalar,
31    value_type: PhantomData<T>,
32}
33
34// Mutable access to GrB_Vector shall occur through a write lock on RwLock<GrB_Matrix>.
35// Code review must consider that the correct lock is made via
36// SparseMatrix::get_write_lock() and SparseMatrix::get_read_lock().
37// https://doc.rust-lang.org/nomicon/send-and-sync.html
38unsafe impl<T: ValueType> Send for SparseScalar<T> {}
39unsafe impl<T: ValueType> Sync for SparseScalar<T> {}
40
41pub unsafe fn new_graphblas_scalar(
42    context: &Arc<Context>,
43    graphblas_value_type: GrB_Type,
44) -> Result<GrB_Scalar, SparseLinearAlgebraError> {
45    let mut scalar: MaybeUninit<GrB_Scalar> = MaybeUninit::uninit();
46
47    context.call_without_detailed_error_information(|| unsafe {
48        GrB_Scalar_new(scalar.as_mut_ptr(), graphblas_value_type)
49    })?;
50
51    let scalar = unsafe { scalar.assume_init() };
52    return Ok(scalar);
53}
54
55impl<T: ValueType> SparseScalar<T> {
56    pub fn new(context: Arc<Context>) -> Result<Self, SparseLinearAlgebraError> {
57        let scalar = unsafe { new_graphblas_scalar(&context, T::to_graphblas_type())? };
58        return Ok(SparseScalar {
59            context,
60            scalar,
61            value_type: PhantomData,
62        });
63    }
64
65    pub unsafe fn from_graphblas_scalar(
66        context: Arc<Context>,
67        scalar: GrB_Scalar,
68    ) -> Result<SparseScalar<T>, SparseLinearAlgebraError> {
69        Ok(SparseScalar {
70            context: context.clone(),
71            scalar,
72            value_type: PhantomData,
73        })
74    }
75
76    // pub fn from_value(
77    //     context: &Arc<Context>,
78    //     value: &T,
79    // ) -> Result<Self, SparseLinearAlgebraError> {
80    //     let mut sparse_scalar = SparseScalar::new(context)?;
81    //     sparse_scalar.set_value(value)?;
82    //     Ok(sparse_scalar)
83    // }
84}
85
86// impl<T: ValueType + BuiltInValueType + SetScalarValue<T>> SparseScalar<T> {
87//     pub fn from_scalar(context: &Arc<Context>, value: T) -> Result<Self, SparseLinearAlgebraError> {
88//         let mut sparse_scalar = SparseScalar::new(context)?;
89//         sparse_scalar.set_value(&value)?;
90//         Ok(sparse_scalar)
91//     }
92// }
93
94macro_rules! sparse_scalar_from_scalar {
95    ($value_type: ty) => {
96        impl SparseScalar<$value_type> {
97            pub fn from_value(
98                context: Arc<Context>,
99                value: $value_type,
100            ) -> Result<Self, SparseLinearAlgebraError> {
101                let mut sparse_scalar = SparseScalar::new(context)?;
102                sparse_scalar.set_value(value)?;
103                Ok(sparse_scalar)
104            }
105        }
106    };
107}
108implement_macro_for_all_value_types!(sparse_scalar_from_scalar);
109
110// impl<T: ValueType + CustomValueType> SparseScalar<T> {
111//     pub fn new_custom_type(
112//         value_type: Arc<RegisteredCustomValueType<T>>,
113//     ) -> Result<Self, SparseLinearAlgebraError> {
114//         let mut scalar: MaybeUninit<GxB_Scalar> = MaybeUninit::uninit();
115//         let context = value_type.context();
116
117//         context.call(|| unsafe {
118//             GxB_Scalar_new(scalar.as_mut_ptr(), value_type.to_graphblas_type())
119//         })?;
120
121//         let scalar = unsafe { scalar.assume_init() };
122//         return Ok(SparseScalar {
123//             context,
124//             scalar,
125//             value_type: PhantomData,
126//         });
127//     }
128// }
129
130impl<T: ValueType> GetContext for SparseScalar<T> {
131    fn context(&self) -> Arc<Context> {
132        self.context.clone()
133    }
134
135    fn context_ref(&self) -> &Arc<Context> {
136        &self.context
137    }
138}
139
140impl<T: ValueType> Collection for SparseScalar<T> {
141    fn clear(&mut self) -> Result<(), SparseLinearAlgebraError> {
142        self.context
143            .call_without_detailed_error_information(|| unsafe { GrB_Scalar_clear(self.scalar) })?;
144        Ok(())
145    }
146
147    fn number_of_stored_elements(&self) -> Result<ElementCount, SparseLinearAlgebraError> {
148        let mut number_of_values: MaybeUninit<GrB_Index> = MaybeUninit::uninit();
149        self.context.call(
150            || unsafe { GrB_Scalar_nvals(number_of_values.as_mut_ptr(), self.scalar) },
151            &self.scalar,
152        )?;
153        let number_of_values = unsafe { number_of_values.assume_init() };
154        Ok(ElementIndex::from_graphblas_index(number_of_values)?)
155    }
156}
157
158impl<T: ValueType> Drop for SparseScalar<T> {
159    fn drop(&mut self) -> () {
160        let _ = self
161            .context
162            .call_without_detailed_error_information(|| unsafe {
163                GrB_Scalar_free(&mut self.scalar)
164            });
165    }
166}
167
168impl<T: ValueType> Clone for SparseScalar<T> {
169    fn clone(&self) -> Self {
170        let mut scalar_copy: MaybeUninit<GrB_Scalar> = MaybeUninit::uninit();
171        self.context
172            .call(
173                || unsafe { GrB_Scalar_dup(scalar_copy.as_mut_ptr(), self.scalar) },
174                &self.scalar,
175            )
176            .unwrap();
177
178        SparseScalar {
179            context: self.context.clone(),
180            scalar: unsafe { scalar_copy.assume_init() },
181            value_type: PhantomData,
182        }
183    }
184}
185
186pub trait GetGraphblasSparseScalar: GetContext {
187    unsafe fn graphblas_scalar(&self) -> GrB_Scalar;
188    unsafe fn graphblas_scalar_ref(&self) -> &GrB_Scalar;
189    unsafe fn graphblas_scalar_mut_ref(&mut self) -> &mut GrB_Scalar;
190}
191
192impl<T: ValueType> GetGraphblasSparseScalar for SparseScalar<T> {
193    unsafe fn graphblas_scalar(&self) -> GrB_Scalar {
194        self.scalar
195    }
196    unsafe fn graphblas_scalar_ref(&self) -> &GrB_Scalar {
197        &self.scalar
198    }
199    unsafe fn graphblas_scalar_mut_ref(&mut self) -> &mut GrB_Scalar {
200        &mut self.scalar
201    }
202}
203
204// // TODO improve printing format
205// // summary data, column aligning
206// impl<T: ValueType + GetScalarValue<T> + Default> std::fmt::Display for SparseScalar<T> {
207//     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
208//         let value: T;
209//         match self.get_value() {
210//             Err(_error) => return Err(std::fmt::Error),
211//             Ok(inner_value) => {
212//                 value = inner_value;
213//             }
214//         }
215//         writeln! {f,"Number of stored elements: {:?}", self.number_of_stored_elements()?};
216//         writeln! {f,"Value: {:?}", value};
217//         writeln!(f, "")
218//     }
219// }
220
221// TODO: make the implementation generic
222// TODO improve printing format
223// summary data, column aligning
224macro_rules! implement_dispay {
225    ($value_type:ty) => {
226        impl std::fmt::Display for SparseScalar<$value_type> {
227            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
228                let value: Option<$value_type>;
229                match self.value() {
230                    Err(_error) => return Err(std::fmt::Error),
231                    Ok(inner_value) => {
232                        value = inner_value;
233                    }
234                }
235                writeln! {f,"Number of stored elements: {:?}", self.number_of_stored_elements()?};
236                writeln! {f,"Value: {:?}", value};
237                writeln!(f, "")
238            }
239        }
240    };
241}
242implement_macro_for_all_value_types!(implement_dispay);
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty `i32` scalar in a fresh default context.
    fn empty_i32_scalar() -> SparseScalar<i32> {
        let context = Context::init_default().unwrap();
        SparseScalar::<i32>::new(context).unwrap()
    }

    #[test]
    fn new_scalar() {
        let scalar = empty_i32_scalar();

        assert_eq!(scalar.number_of_stored_elements().unwrap(), 0);
    }

    #[test]
    fn clone_scalar() {
        let context = Context::init_default().unwrap();
        let original = SparseScalar::<f32>::new(context).unwrap();

        let duplicate = original.clone();

        // TODO: implement and test equality operator
        assert_eq!(duplicate.number_of_stored_elements().unwrap(), 0);
    }

    #[test]
    fn test_set_value() {
        let mut scalar = empty_i32_scalar();

        scalar.set_value(2).unwrap();

        assert_eq!(scalar.number_of_stored_elements().unwrap(), 1);
    }

    #[test]
    fn clear_value_from_scalar() {
        let mut scalar = empty_i32_scalar();
        scalar.set_value(2).unwrap();

        assert_eq!(scalar.number_of_stored_elements().unwrap(), 1);
        assert_eq!(scalar.value_or_default().unwrap(), 2);

        scalar.clear().unwrap();

        assert_eq!(scalar.number_of_stored_elements().unwrap(), 0)
    }

    #[test]
    fn test_get_value() {
        let mut scalar = empty_i32_scalar();
        scalar.set_value(2).unwrap();

        assert_eq!(scalar.number_of_stored_elements().unwrap(), 1);
        assert_eq!(scalar.value_or_default().unwrap(), 2);
    }
}
313        assert_eq!(1, sparse_scalar.number_of_stored_elements().unwrap());
314
315        assert_eq!(2, sparse_scalar.value_or_default().unwrap());
316    }
317}