1use nalgebra::{DMatrix, Scalar};
6use paradis_core::{BoundedParAccess, Bounds, ParAccess};
7use std::marker::PhantomData;
8
/// Parallel mutable access to the entries of a [`DMatrix`].
///
/// Stores a raw pointer to the matrix's contiguous storage together with its
/// dimensions, so that several threads can each obtain a `&mut T` to *distinct*
/// entries at the same time. The lifetime `'a` ties the access to the exclusive
/// borrow of the matrix it was created from.
pub struct DMatrixParAccessMut<'a, T> {
    /// Pointer to the first element of the matrix storage.
    ptr: *mut T,
    /// Number of rows of the matrix.
    rows: usize,
    /// Number of columns of the matrix.
    cols: usize,
    // `&'a mut T` rather than `&'a T`: this type hands out mutable references, so it
    // must be invariant in `T` and behave like an exclusive borrow for variance and
    // drop-check purposes.
    marker: PhantomData<&'a mut T>,
}

// SAFETY: moving the access to another thread lets that thread create `&mut T`
// records, which is only sound when `T` itself may be sent across threads.
unsafe impl<T: Send> Send for DMatrixParAccessMut<'_, T> {}
// SAFETY: a shared `&DMatrixParAccessMut` lets every thread create `&mut T` records
// for (disjoint) entries. That effectively transfers each `T` to those threads, so
// `T: Send` is required. Keeping records disjoint is the caller's obligation.
unsafe impl<T: Send> Sync for DMatrixParAccessMut<'_, T> {}
18
19impl<'a, T> DMatrixParAccessMut<'a, T> {
20 pub fn from_matrix_mut(matrix: &'a mut DMatrix<T>) -> Self {
21 Self {
22 rows: matrix.nrows(),
23 cols: matrix.ncols(),
24 marker: Default::default(),
25 ptr: matrix.as_mut_ptr(),
26 }
27 }
28}
29
30unsafe impl<'a, T: Scalar + Send> ParAccess<(usize, usize)> for DMatrixParAccessMut<'a, T> {
31 type Record = &'a mut T;
32
33 unsafe fn clone_access(&self) -> Self {
34 Self {
35 ptr: self.ptr,
36 rows: self.rows,
37 cols: self.cols,
38 marker: self.marker,
39 }
40 }
41
42 unsafe fn get_unsync_unchecked(&self, (i, j): (usize, usize)) -> Self::Record {
43 let linear_idx = j * self.rows + i;
45 &mut *self.ptr.add(linear_idx)
46 }
47}
48
49unsafe impl<'a, T: Scalar + Send> BoundedParAccess<(usize, usize)> for DMatrixParAccessMut<'a, T> {
50 fn bounds(&self) -> Bounds<(usize, usize)> {
51 Bounds {
52 offset: (0, 0),
53 extent: (self.rows, self.cols),
54 }
55 }
56
57 fn in_bounds(&self, (i, j): (usize, usize)) -> bool {
58 i < self.rows && j < self.cols
59 }
60}