mop_structs/matrix/csr_matrix/
csr_matrix_row_iter_impls.rs

1use crate::{
2  dim::Dim,
3  vec::css::{CssSlice, CssSliceMut},
4};
5use core::slice::{from_raw_parts, from_raw_parts_mut};
6
7macro_rules! impl_iter (
8    (
9        $csr_matrix_row_iter:ident,
10        $data_ptr:ty,
11        $data_type:ty,
12        $from_raw_parts:ident,
13        $css_slice:ident
14    ) => (
15
16#[derive(Debug)]
17pub struct $csr_matrix_row_iter<'a, T: 'a> {
18    curr_row: usize,
19    data: $data_ptr,
20    dim: Dim<[usize; 2]>,
21    indcs: &'a [usize],
22    ptrs: &'a [usize],
23}
24
25impl<'a, T> $csr_matrix_row_iter<'a, T> {
26    pub(crate) fn new(
27        dim: [usize; 2],
28        data: $data_ptr,
29        indcs: &'a [usize],
30        ptrs: &'a [usize]
31    ) -> Self {
32        $csr_matrix_row_iter {
33            curr_row: 0,
34            data,
35            dim: dim.into(),
36            indcs,
37            ptrs,
38        }
39    }
40
41    pub fn split_at(self, idx: usize) -> (Self, Self) {
42        let current_len = self.dim.rows() - self.curr_row;
43        assert!(idx <= current_len);
44        let slice_point = self.curr_row + idx;
45        (
46            $csr_matrix_row_iter {
47                curr_row: self.curr_row,
48                data: self.data,
49                dim: [slice_point, self.dim.cols()].into(),
50                indcs: self.indcs,
51                ptrs: self.ptrs,
52            },
53            $csr_matrix_row_iter {
54                curr_row: slice_point,
55                data: self.data,
56                dim: self.dim,
57                indcs: self.indcs,
58                ptrs: self.ptrs,
59            },
60        )
61    }
62}
63
64impl<'a, T> DoubleEndedIterator for $csr_matrix_row_iter<'a, T> {
65    fn next_back(&mut self) -> Option<Self::Item> {
66        if self.curr_row >= self.dim.rows() {
67            return None;
68        }
69        let starting_row_ptr = self.ptrs[self.dim.rows() - 1];
70        let finishing_row_ptr = self.ptrs[self.dim.rows()];
71        let data: $data_type = unsafe {
72            let ptr = self.data.add(starting_row_ptr);
73            let len = finishing_row_ptr - starting_row_ptr;
74            $from_raw_parts(ptr, len)
75        };
76        *self.dim.rows_mut() -= 1;
77        Some($css_slice::new(
78            self.dim.cols(),
79            data,
80            &self.indcs[starting_row_ptr..finishing_row_ptr],
81        ))
82    }
83}
84
85impl<'a, T> ExactSizeIterator for $csr_matrix_row_iter<'a, T> {}
86
87impl<'a, T> Iterator for $csr_matrix_row_iter<'a, T> {
88    type Item = $css_slice<'a, T>;
89
90    fn next(&mut self) -> Option<Self::Item> {
91        if self.curr_row >= self.dim.rows() {
92            return None;
93        }
94
95        let starting_row_ptr = self.ptrs[self.curr_row];
96        let finishing_row_ptr = self.ptrs[self.curr_row + 1];
97
98        let data: $data_type = unsafe {
99            let ptr = self.data.add(starting_row_ptr);
100            $from_raw_parts(ptr, finishing_row_ptr - starting_row_ptr)
101        };
102
103        self.curr_row += 1;
104
105        Some($css_slice::new(
106            self.dim.cols(),
107            data,
108            &self.indcs[starting_row_ptr..finishing_row_ptr],
109        ))
110    }
111
112    fn size_hint(&self) -> (usize, Option<usize>) {
113        (self.dim.rows(), Some(self.dim.rows()))
114    }
115}
116
117unsafe impl<'a, T> Send for $csr_matrix_row_iter<'a, T> {}
118unsafe impl<'a, T> Sync for $csr_matrix_row_iter<'a, T> {}
119
120    );
121);
122
123impl_iter!(
124  CsrMatrixRowIter,
125  *const T,
126  &'a [T],
127  from_raw_parts,
128  CssSlice
129);
130impl_iter!(
131  CsrMatrixRowIterMut,
132  *mut T,
133  &'a mut [T],
134  from_raw_parts_mut,
135  CssSliceMut
136);