1use crate::coo::CooMatrix;
9use crate::csc::CscMatrix;
10use crate::csr::CsrMatrix;
11use crate::pattern::SparsityPattern;
12use nalgebra::proptest::DimRange;
13use nalgebra::{Dim, Scalar};
14use proptest::collection::{btree_set, hash_map, vec};
15use proptest::prelude::*;
16use proptest::sample::Index;
17use std::cmp::min;
18use std::iter::repeat;
19
20fn dense_row_major_coord_strategy(
21 nrows: usize,
22 ncols: usize,
23 nnz: usize,
24) -> impl Strategy<Value = Vec<(usize, usize)>> {
25 assert!(nnz <= nrows * ncols);
26 let mut booleans = vec![true; nnz];
27 booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
28 Just(booleans)
30 .prop_shuffle()
32 .prop_map(move |booleans| {
33 booleans
34 .into_iter()
35 .enumerate()
36 .filter_map(|(index, is_entry)| {
37 if is_entry {
38 let i = index / ncols;
40 let j = index % ncols;
41 Some((i, j))
42 } else {
43 None
44 }
45 })
46 .collect::<Vec<_>>()
47 })
48}
49
50fn dense_triplet_strategy<T>(
54 value_strategy: T,
55 nrows: usize,
56 ncols: usize,
57 nnz: usize,
58) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
59where
60 T: Strategy + Clone + 'static,
61 T::Value: Scalar,
62{
63 assert!(nnz <= nrows * ncols);
64
65 let booleans: Vec<_> = repeat(true)
67 .take(nnz)
68 .chain(repeat(false))
69 .take(nrows * ncols)
70 .collect();
71
72 Just(booleans)
73 .prop_shuffle()
75 .prop_map(move |booleans| {
77 booleans
78 .into_iter()
79 .enumerate()
80 .filter_map(|(index, is_entry)| {
81 if is_entry {
82 let i = index / ncols;
84 let j = index % ncols;
85 Some((i, j))
86 } else {
87 None
88 }
89 })
90 .collect::<Vec<_>>()
91 })
92 .prop_flat_map(move |coords| {
94 vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
95 coords
96 .clone()
97 .into_iter()
98 .zip(values)
99 .map(|((i, j), v)| (i, j, v))
100 .collect::<Vec<_>>()
101 })
102 })
103}
104
105fn sparse_triplet_strategy<T>(
110 value_strategy: T,
111 nrows: usize,
112 ncols: usize,
113 nnz: usize,
114) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
115where
116 T: Strategy + Clone + 'static,
117 T::Value: Scalar,
118{
119 let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
121 let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
122 let coord_strategy = (row_index_strategy, col_index_strategy);
123 hash_map(coord_strategy, value_strategy.clone(), nnz)
124 .prop_map(|hash_map| {
125 let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
126 triplets
127 })
128 .prop_shuffle()
132}
133
134pub fn coo_no_duplicates<T>(
141 value_strategy: T,
142 rows: impl Into<DimRange>,
143 cols: impl Into<DimRange>,
144 max_nonzeros: usize,
145) -> impl Strategy<Value = CooMatrix<T::Value>>
146where
147 T: Strategy + Clone + 'static,
148 T::Value: Scalar,
149{
150 (
151 rows.into().to_range_inclusive(),
152 cols.into().to_range_inclusive(),
153 )
154 .prop_flat_map(move |(nrows, ncols)| {
155 let max_nonzeros = min(max_nonzeros, nrows * ncols);
156 let size_range = 0..=max_nonzeros;
157 let value_strategy = value_strategy.clone();
158
159 size_range
160 .prop_flat_map(move |nnz| {
161 let value_strategy = value_strategy.clone();
162 if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
163 dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
166 } else {
167 sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
170 }
171 })
172 .prop_map(move |triplets| {
173 let mut coo = CooMatrix::new(nrows, ncols);
174 for (i, j, v) in triplets {
175 coo.push(i, j, v);
176 }
177 coo
178 })
179 })
180}
181
182pub fn coo_with_duplicates<T>(
194 value_strategy: T,
195 rows: impl Into<DimRange>,
196 cols: impl Into<DimRange>,
197 max_nonzeros: usize,
198 max_duplicates: usize,
199) -> impl Strategy<Value = CooMatrix<T::Value>>
200where
201 T: Strategy + Clone + 'static,
202 T::Value: Scalar,
203{
204 let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
205 let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
206 (coo_strategy, duplicate_strategy)
207 .prop_flat_map(|(coo, duplicates)| {
208 let mut triplets: Vec<(usize, usize, T::Value)> = coo
209 .triplet_iter()
210 .map(|(i, j, v)| (i, j, v.clone()))
211 .collect();
212 if !triplets.is_empty() {
213 let duplicates_iter: Vec<_> = duplicates
214 .into_iter()
215 .map(|(idx, val)| {
216 let (i, j, _) = idx.get(&triplets);
217 (*i, *j, val)
218 })
219 .collect();
220 triplets.extend(duplicates_iter);
221 }
222 let shuffled = Just(triplets).prop_shuffle();
224 (Just(coo.nrows()), Just(coo.ncols()), shuffled)
225 })
226 .prop_map(move |(nrows, ncols, triplets)| {
227 let mut coo = CooMatrix::new(nrows, ncols);
228 for (i, j, v) in triplets {
229 coo.push(i, j, v);
230 }
231 coo
232 })
233}
234
235fn sparsity_pattern_from_row_major_coords<I>(
236 nmajor: usize,
237 nminor: usize,
238 coords: I,
239) -> SparsityPattern
240where
241 I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
242{
243 let mut minors = Vec::with_capacity(coords.len());
244 let mut offsets = Vec::with_capacity(nmajor + 1);
245 let mut current_major = 0;
246 offsets.push(0);
247 for (idx, (i, j)) in coords.enumerate() {
248 assert!(i >= current_major);
249 assert!(
250 i < nmajor && j < nminor,
251 "Generated coords are out of bounds"
252 );
253 while current_major < i {
254 offsets.push(idx);
255 current_major += 1;
256 }
257 minors.push(j);
258 }
259
260 while current_major < nmajor {
261 offsets.push(minors.len());
262 current_major += 1;
263 }
264
265 assert_eq!(offsets.first().unwrap(), &0);
266 assert_eq!(offsets.len(), nmajor + 1);
267
268 SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
269 .expect("Internal error: Generated sparsity pattern is invalid")
270}
271
272pub fn sparsity_pattern(
274 major_lanes: impl Into<DimRange>,
275 minor_lanes: impl Into<DimRange>,
276 max_nonzeros: usize,
277) -> impl Strategy<Value = SparsityPattern> {
278 (
279 major_lanes.into().to_range_inclusive(),
280 minor_lanes.into().to_range_inclusive(),
281 )
282 .prop_flat_map(move |(nmajor, nminor)| {
283 let max_nonzeros = min(nmajor * nminor, max_nonzeros);
284 (Just(nmajor), Just(nminor), 0..=max_nonzeros)
285 })
286 .prop_flat_map(move |(nmajor, nminor, nnz)| {
287 if 10 * nnz < nmajor * nminor {
288 btree_set((0..nmajor, 0..nminor), nnz)
290 .prop_map(move |coords| {
291 sparsity_pattern_from_row_major_coords(nmajor, nminor, coords.into_iter())
292 })
293 .boxed()
294 } else {
295 dense_row_major_coord_strategy(nmajor, nminor, nnz)
298 .prop_map(move |coords| {
299 let coords = coords.into_iter();
300 sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
301 })
302 .boxed()
303 }
304 })
305}
306
307pub fn csr<T>(
309 value_strategy: T,
310 rows: impl Into<DimRange>,
311 cols: impl Into<DimRange>,
312 max_nonzeros: usize,
313) -> impl Strategy<Value = CsrMatrix<T::Value>>
314where
315 T: Strategy + Clone + 'static,
316 T::Value: Scalar,
317{
318 let rows = rows.into();
319 let cols = cols.into();
320 sparsity_pattern(
321 rows.lower_bound().value()..=rows.upper_bound().value(),
322 cols.lower_bound().value()..=cols.upper_bound().value(),
323 max_nonzeros,
324 )
325 .prop_flat_map(move |pattern| {
326 let nnz = pattern.nnz();
327 let values = vec![value_strategy.clone(); nnz];
328 (Just(pattern), values)
329 })
330 .prop_map(|(pattern, values)| {
331 CsrMatrix::try_from_pattern_and_values(pattern, values)
332 .expect("Internal error: Generated CsrMatrix is invalid")
333 })
334}
335
336pub fn csc<T>(
338 value_strategy: T,
339 rows: impl Into<DimRange>,
340 cols: impl Into<DimRange>,
341 max_nonzeros: usize,
342) -> impl Strategy<Value = CscMatrix<T::Value>>
343where
344 T: Strategy + Clone + 'static,
345 T::Value: Scalar,
346{
347 let rows = rows.into();
348 let cols = cols.into();
349 sparsity_pattern(
350 cols.lower_bound().value()..=cols.upper_bound().value(),
351 rows.lower_bound().value()..=rows.upper_bound().value(),
352 max_nonzeros,
353 )
354 .prop_flat_map(move |pattern| {
355 let nnz = pattern.nnz();
356 let values = vec![value_strategy.clone(); nnz];
357 (Just(pattern), values)
358 })
359 .prop_map(|(pattern, values)| {
360 CscMatrix::try_from_pattern_and_values(pattern, values)
361 .expect("Internal error: Generated CscMatrix is invalid")
362 })
363}