fenris_sparse/sparse.rs

//! Functionality for sparse linear algebra.
//!
//! Some of it is intended to be ported to `nalgebra-sparse` later.

use fenris_paradis::{ParallelIndexedAccess, ParallelIndexedCollection};
use nalgebra_sparse::pattern::SparsityPattern;
use nalgebra_sparse::CsrMatrix;
use std::slice;

// // TODO: Do we want to try to remove duplicates? Probably not...
// pub fn from_offsets_and_unsorted_indices(
//     major_dim: usize,
//     minor_dim: usize,
//     major_offsets: Vec<usize>,
//     mut minor_indices: Vec<usize>,
// ) -> Self {
//     assert_eq!(major_offsets.len(), major_dim + 1);
//     assert_eq!(*major_offsets.last().unwrap(), minor_indices.len());
//     if major_offsets
//         .iter()
//         .tuple_windows()
//         .any(|(prev, next)| prev > next)
//     {
//         panic!("Offsets must be non-decreasing.");
//     }
//
//     for (major_begin, major_end) in major_offsets.iter().tuple_windows() {
//         let minor = &mut minor_indices[*major_begin..*major_end];
//         minor.sort_unstable();
//         if minor
//             .iter()
//             .tuple_windows()
//             .any(|(prev, next)| prev >= next)
//         {
//             panic!("Minor indices contain duplicates");
//         }
//     }
//
//     Self {
//         major_offsets,
//         minor_indices,
//         minor_dim,
//     }
// }

// /// Appends another sparsity pattern to this one, in the sense that it is extended
// /// along its major dimension.
// ///
// /// Panics if `self` and `other` have different minor dimensions.
// pub fn append_pattern(&mut self, other: &SparsityPattern) {
//     assert_eq!(self.minor_dim(), other.minor_dim());
//
//     let offset_begin = *self.major_offsets.last().unwrap();
//     let new_offsets_iter = other
//         .major_offsets()
//         .iter()
//         .map(|offset| offset + offset_begin);
//
//     self.major_offsets.pop();
//     self.major_offsets.extend(new_offsets_iter);
//     self.minor_indices.extend_from_slice(&other.minor_indices);
// }

// // TODO: Write tests
// pub fn diag_iter<'a>(&'a self) -> impl 'a + Iterator<Item = T>
// where
//     T: Zero + Clone,
// {
//     let ia = self.row_offsets();
//     let ja = self.column_indices();
//     (0..self.nrows()).map(move |i| {
//         let row_begin = ia[i];
//         let row_end = ia[i + 1];
//         let columns_in_row = &ja[row_begin..row_end];
//         if let Ok(idx) = columns_in_row.binary_search(&i) {
//             self.values()[row_begin + idx].clone()
//         } else {
//             T::zero()
//         }
//     })
// }

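// Note on `diag_iter` above: it yields exactly one value per row, substituting an
// explicit `T::zero()` for rows whose diagonal entry is not stored in the pattern,
// so the iterator always has length `self.nrows()`.
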
// pub fn from_diagonal<'a>(diagonal: impl Into<DVectorView<'a, T>>) -> Self
// where
//     T: Scalar,
// {
//     let diagonal = diagonal.into();
//     let vals = diagonal.iter().cloned().collect();
//     let num_rows = diagonal.len();
//     let ia = (0..(num_rows + 1)).collect();
//     let ja = (0..num_rows).collect();
//     Self::from_csr_data(num_rows, num_rows, ia, ja, vals)
// }
//
// pub fn from_pattern_and_values(pattern: Arc<SparsityPattern>, values: Vec<T>) -> Self {
//     assert_eq!(pattern.nnz(), values.len());
//     Self {
//         sparsity_pattern: pattern,
//         v: values,
//     }
// }

// /// Computes `self += a * x`, where `x` is another matrix.
// ///
// /// Panics if the two matrices do not have the same number of explicitly stored entries.
// pub fn add_assign_scaled(&mut self, a: T, x: &Self)
// where
//     T: Clone + ClosedAdd + ClosedMul,
// {
//     assert_eq!(self.values_mut().len(), x.values().len());
//     for (v_i, x_i) in self.values_mut().iter_mut().zip(x.values().iter()) {
//         *v_i += a.clone() * x_i.clone();
//     }
// }

// pub fn append_csr_rows(&mut self, other: &CsrMatrix<T>)
// where
//     T: Clone,
// {
//     Arc::make_mut(&mut self.sparsity_pattern).append_pattern(&other.sparsity_pattern());
//     self.v.extend_from_slice(other.values());
// }

// pub fn concat_diagonally(matrices: &[CsrMatrix<T>]) -> CsrMatrix<T>
// where
//     T: Clone,
// {
//     let mut num_rows = 0;
//     let mut num_cols = 0;
//     let mut nnz = 0;
//
//     // This first pass over the matrices is cheap, since we don't access any of the data.
//     // We use this to be able to pre-allocate enough capacity so that no further
//     // reallocation will be necessary.
//     for matrix in matrices {
//         num_rows += matrix.nrows();
//         num_cols += matrix.ncols();
//         nnz += matrix.nnz();
//     }
//
//     let mut values = Vec::with_capacity(nnz);
//     let mut column_indices = Vec::with_capacity(nnz);
//     let mut row_offsets = Vec::with_capacity(num_rows + 1);
//
//     let mut col_offset = 0;
//     let mut nnz_offset = 0;
//     for matrix in matrices {
//         values.extend_from_slice(matrix.values());
//         column_indices.extend(matrix.column_indices().iter().map(|i| *i + col_offset));
//         row_offsets.extend(
//             matrix
//                 .row_offsets()
//                 .iter()
//                 .take(matrix.nrows())
//                 .map(|offset| *offset + nnz_offset),
//         );
//
//         col_offset += matrix.ncols();
//         nnz_offset += matrix.nnz();
//     }
//
//     row_offsets.push(nnz);
//
//     Self {
//         // TODO: Avoid validation of pattern for performance
//         sparsity_pattern: Arc::new(SparsityPattern::from_offsets_and_indices(
//             num_rows,
//             num_cols,
//             row_offsets,
//             column_indices,
//         )),
//         v: values,
//     }
// }

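// For reference: `concat_diagonally(&[a, b])` produces the block-diagonal matrix
//     [ a 0 ]
//     [ 0 b ]
// so the result has `a.nrows() + b.nrows()` rows, `a.ncols() + b.ncols()` columns,
// and `a.nnz() + b.nnz()` stored entries, as computed by the first pass above.
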
/// An immutable view of a single row of a CSR matrix, for use with parallel row access.
pub struct ParCsrRow<'a, T> {
    column_indices: &'a [usize],
    values: &'a [T],
}

/// A mutable view of a single row of a CSR matrix.
///
/// The values are stored as a raw pointer rather than a slice so that disjoint rows
/// can be mutated concurrently; the pointer always refers to exactly
/// `column_indices.len()` values.
pub struct ParCsrRowMut<'a, T> {
    column_indices: &'a [usize],
    values: *mut T,
}

impl<'a, T> ParCsrRow<'a, T> {
    /// Number of non-zeros in this row.
    pub fn nnz(&self) -> usize {
        self.column_indices.len()
    }

    /// The explicitly stored values in this row.
    pub fn values(&self) -> &[T] {
        self.values
    }

    /// The column indices of the explicitly stored entries in this row.
    pub fn col_indices(&self) -> &[usize] {
        self.column_indices
    }
}

impl<'a, T> ParCsrRowMut<'a, T> {
    /// Number of non-zeros in this row.
    pub fn nnz(&self) -> usize {
        self.column_indices.len()
    }

    /// Mutable access to the explicitly stored values in this row.
    pub fn values_mut(&mut self) -> &mut [T] {
        // SAFETY: `values` points to exactly `column_indices.len()` values belonging
        // to this row, and construction of `ParCsrRowMut` ensures that no other
        // simultaneously accessible view aliases them.
        unsafe { slice::from_raw_parts_mut(self.values, self.column_indices.len()) }
    }

    /// Simultaneous access to the column indices and mutable values of this row.
    pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
        // SAFETY: Same argument as in `values_mut`.
        let values_mut = unsafe { slice::from_raw_parts_mut(self.values, self.column_indices.len()) };
        (self.column_indices, values_mut)
    }
}

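// A hedged, illustrative example (not part of the original API) of the kind of
// per-row kernel these row views are designed for: it zeroes every entry strictly
// to the left of the diagonal in the given row. The function and its name are
// hypothetical.
#[allow(dead_code)]
fn zero_strictly_lower_entries(row_index: usize, row: &mut ParCsrRowMut<'_, f64>) {
    let (cols, values) = row.cols_and_values_mut();
    for (j, v) in cols.iter().zip(values) {
        if *j < row_index {
            *v = 0.0;
        }
    }
}
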
/// Wrapper for a `CsrMatrix` that allows it to be interpreted as a parallel collection of rows.
pub struct ParallelCsrRowCollection<'a, T>(pub &'a mut CsrMatrix<T>);

/// Parallel access object for the rows of a CSR matrix.
#[derive(Copy)]
pub struct CsrParallelRowAccess<'a, T> {
    pattern: &'a SparsityPattern,
    values_ptr: *mut T,
}

// `Clone` is implemented manually so that it does not require `T: Clone`;
// cloning the access object only copies a reference and a raw pointer.
impl<'a, T> Clone for CsrParallelRowAccess<'a, T> {
    fn clone(&self) -> Self {
        Self {
            pattern: self.pattern,
            values_ptr: self.values_ptr,
        }
    }
}

// SAFETY: The access object only exposes the matrix data through references and
// raw pointers, so sharing or sending it across threads is sound precisely when
// `T` itself is `Sync`/`Send`.
unsafe impl<'a, T: 'a + Sync> Sync for CsrParallelRowAccess<'a, T> {}
unsafe impl<'a, T: 'a + Send> Send for CsrParallelRowAccess<'a, T> {}

unsafe impl<'a, 'b, T: 'a + Sync + Send> ParallelIndexedAccess<'b> for CsrParallelRowAccess<'a, T>
where
    'a: 'b,
{
    type Record = ParCsrRow<'b, T>;
    type RecordMut = ParCsrRowMut<'b, T>;

    unsafe fn get_unchecked(&self, global_index: usize) -> Self::Record {
        // Caller contract: `global_index` must be a valid row index, and no mutable
        // access to the same row may exist at the same time.
        let major_offsets = self.pattern.major_offsets();
        let row_begin = *major_offsets.get_unchecked(global_index);
        let row_end = *major_offsets.get_unchecked(global_index + 1);
        let column_indices = &self.pattern.minor_indices()[row_begin..row_end];
        let values_ptr = self.values_ptr.add(row_begin);
        let values = slice::from_raw_parts(values_ptr, column_indices.len());
        ParCsrRow { column_indices, values }
    }

    unsafe fn get_unchecked_mut(&self, global_index: usize) -> Self::RecordMut {
        // Caller contract: `global_index` must be a valid row index, and no other
        // access to the same row may exist at the same time.
        let major_offsets = self.pattern.major_offsets();
        let row_begin = *major_offsets.get_unchecked(global_index);
        let row_end = *major_offsets.get_unchecked(global_index + 1);
        let column_indices = &self.pattern.minor_indices()[row_begin..row_end];
        let values_ptr = self.values_ptr.add(row_begin);
        ParCsrRowMut {
            column_indices,
            values: values_ptr,
        }
    }
}

unsafe impl<'a, T: 'a + Sync + Send> ParallelIndexedCollection<'a> for ParallelCsrRowCollection<'a, T> {
    type Access = CsrParallelRowAccess<'a, T>;

    unsafe fn create_access(&'a mut self) -> Self::Access {
        // TODO: Instead of storing a reference to the sparsity pattern we should probably
        // rather store the CSR data directly
        let values_ptr = self.0.values_mut().as_mut_ptr();
        let pattern = self.0.pattern();
        CsrParallelRowAccess { pattern, values_ptr }
    }

    fn len(&self) -> usize {
        self.0.nrows()
    }
}

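// A minimal usage sketch (an addition, not part of the original source) showing how
// the row collection is meant to be consumed: create an access object, then visit
// rows through it. The sequential loop stands in for a parallel visit; soundness
// relies on each row being accessed mutably by at most one visitor at a time.
#[cfg(test)]
mod parallel_row_access_tests {
    use super::*;

    #[test]
    fn row_access_reads_and_scales_rows() {
        // 2x2 CSR matrix [[1.0, 2.0], [0.0, 3.0]]
        let mut matrix =
            CsrMatrix::try_from_csr_data(2, 2, vec![0, 2, 3], vec![0, 1, 1], vec![1.0, 2.0, 3.0]).unwrap();

        let mut collection = ParallelCsrRowCollection(&mut matrix);
        let n = collection.len();
        unsafe {
            let access = collection.create_access();
            // Read-only view of row 0.
            let row = access.get_unchecked(0);
            assert_eq!(row.nnz(), 2);
            assert_eq!(row.col_indices(), &[0, 1][..]);
            assert_eq!(row.values(), &[1.0, 2.0][..]);
            // Mutate every row exactly once.
            for i in 0..n {
                let mut row = access.get_unchecked_mut(i);
                for v in row.values_mut() {
                    *v *= 2.0;
                }
            }
        }
        assert_eq!(matrix.values(), &[2.0, 4.0, 6.0][..]);
    }
}
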
// impl<T> CsrMatrix<T>
// where
//     T: Real,
// {
//     pub fn scale_rows<'a>(&mut self, diagonal_matrix: impl Into<DVectorView<'a, T>>) {
//         let diag = diagonal_matrix.into();
//         assert_eq!(diag.len(), self.nrows());
//         self.transform_values(|i, _, v| *v *= diag[i]);
//     }
//
//     pub fn scale_cols<'a>(&mut self, diagonal_matrix: impl Into<DVectorView<'a, T>>) {
//         let diag = diagonal_matrix.into();
//         assert_eq!(diag.len(), self.ncols());
//         self.transform_values(|_, j, v| *v *= diag[j]);
//     }
// }
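
// A hedged sketch (an addition, not in the original source) of the commented-out
// `scale_rows` above, written as a free function against the public
// `nalgebra-sparse` API, whose `triplet_iter_mut` yields `(row, col, &mut value)`
// for every stored entry. The function name is illustrative.
#[allow(dead_code)]
fn scale_rows_by_diagonal(matrix: &mut CsrMatrix<f64>, diag: &[f64]) {
    assert_eq!(diag.len(), matrix.nrows());
    for (i, _j, v) in matrix.triplet_iter_mut() {
        *v *= diag[i];
    }
}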