1use matrixcompare_core::{Access, DenseAccess, Matrix, SparseAccess};
6use proptest::prelude::*;
7use std::fmt::Debug;
8
9use num::Zero;
10use std::ops::Range;
11
/// A dense matrix backed by a flat `Vec` in row-major order, used as a test double.
#[derive(Debug, Clone)]
pub struct MockDenseMatrix<T> {
    /// Entries in row-major order (row 0 first, then row 1, ...).
    data: Vec<T>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
18
/// A sparse matrix stored as a list of `(row, col, value)` triplets, used as a test double.
#[derive(Debug, Clone)]
pub struct MockSparseMatrix<T> {
    /// Dimensions as `(rows, cols)`.
    shape: (usize, usize),
    /// Explicitly stored entries as `(row, col, value)`; ordering and uniqueness are not
    /// enforced by this type.
    triplets: Vec<(usize, usize, T)>,
}
24
25impl<T> MockSparseMatrix<T> {
26 pub fn from_triplets(rows: usize, cols: usize, triplets: Vec<(usize, usize, T)>) -> Self {
27 Self {
28 shape: (rows, cols),
29 triplets,
30 }
31 }
32
33 pub fn take_triplets(self) -> Vec<(usize, usize, T)> {
34 self.triplets
35 }
36}
37
38impl<T> MockSparseMatrix<T>
39where
40 T: Zero + Clone,
41{
42 pub fn to_dense(&self) -> Result<MockDenseMatrix<T>, ()> {
43 let (r, c) = (self.rows(), self.cols());
44 let mut result =
45 MockDenseMatrix::from_row_major(self.rows(), self.cols(), vec![T::zero(); r * c]);
46 for (i, j, v) in &self.triplets {
47 *result.get_mut(*i, *j).ok_or(())? = v.clone();
48 }
49
50 Ok(result)
51 }
52}
53
54impl<T> MockDenseMatrix<T> {
55 pub fn from_row_major(rows: usize, cols: usize, data: Vec<T>) -> Self {
56 assert_eq!(
57 rows * cols,
58 data.len(),
59 "Data must have rows*cols number of elements."
60 );
61 Self { data, rows, cols }
62 }
63
64 fn get_linear_index(&self, i: usize, j: usize) -> Option<usize> {
65 if i < self.rows && j < self.cols {
66 Some(i * self.cols + j)
67 } else {
68 None
69 }
70 }
71
72 pub fn get(&self, i: usize, j: usize) -> Option<&T> {
73 self.get_linear_index(i, j).map(|idx| &self.data[idx])
74 }
75
76 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut T> {
77 self.get_linear_index(i, j)
78 .map(move |idx| &mut self.data[idx])
79 }
80}
81
82impl<T: Clone> Matrix<T> for MockDenseMatrix<T> {
83 fn rows(&self) -> usize {
84 self.rows
85 }
86
87 fn cols(&self) -> usize {
88 self.cols
89 }
90
91 fn access(&self) -> Access<T> {
92 Access::Dense(self)
93 }
94}
95
96impl<T: Clone> DenseAccess<T> for MockDenseMatrix<T> {
97 fn fetch_single(&self, row: usize, col: usize) -> T {
98 let idx = row * self.cols + col;
99 self.data[idx].clone()
100 }
101}
102
103impl<T: Clone> Matrix<T> for MockSparseMatrix<T> {
104 fn rows(&self) -> usize {
105 self.shape.0
106 }
107
108 fn cols(&self) -> usize {
109 self.shape.1
110 }
111
112 fn access(&self) -> Access<T> {
113 Access::Sparse(self)
114 }
115}
116
117impl<T: Clone> SparseAccess<T> for MockSparseMatrix<T> {
118 fn nnz(&self) -> usize {
119 self.triplets.len()
120 }
121
122 fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
123 self.triplets.clone()
124 }
125}
126
/// Builds a `MockDenseMatrix` from a literal: entries within a row are separated by `,`
/// and rows by `;`. `mock_matrix![]` produces an empty 0×0 matrix.
///
/// ```ignore
/// let m = mock_matrix![1, 2, 3;
///                      4, 5, 6];
/// ```
///
/// Rows of different lengths fail to compile, because they form a nested array literal.
#[macro_export]
macro_rules! mock_matrix {
    () => {
        $crate::MockDenseMatrix::from_row_major(0, 0, vec![])
    };
    ($( $( $x: expr ),*);*) => {
        {
            let nested = [ $( [ $($x),* ] ),* ];
            let row_count = nested.len();
            let col_count = nested[0].len();
            let mut flat = Vec::with_capacity(row_count * col_count);
            for row in nested.iter() {
                flat.extend(row.iter().cloned());
            }
            $crate::MockDenseMatrix::from_row_major(row_count, col_count, flat)
        }
    };
}
153
/// The `i64` values used by the integer strategies: `-100` inclusive up to `100` exclusive.
pub fn i64_range() -> Range<i64> {
    Range {
        start: -100,
        end: 100,
    }
}
157
158pub fn dense_matrix_strategy<T, S>(
159 rows: impl Strategy<Value = usize>,
160 cols: impl Strategy<Value = usize>,
161 strategy: S,
162) -> impl Strategy<Value = MockDenseMatrix<T>>
163where
164 T: Debug,
165 S: Clone + Strategy<Value = T>,
166{
167 (rows, cols).prop_flat_map(move |(r, c)| {
168 proptest::collection::vec(strategy.clone(), r * c)
169 .prop_map(move |data| MockDenseMatrix::from_row_major(r, c, data))
170 })
171}
172
173pub fn dense_matrix_strategy_i64(
174 rows: impl Strategy<Value = usize>,
175 cols: impl Strategy<Value = usize>,
176) -> impl Strategy<Value = MockDenseMatrix<i64>> {
177 dense_matrix_strategy(rows, cols, i64_range())
178}
179
180pub fn dense_matrix_strategy_normal_f64(
182 rows: impl Strategy<Value = usize>,
183 cols: impl Strategy<Value = usize>,
184) -> impl Strategy<Value = MockDenseMatrix<f64>> {
185 dense_matrix_strategy(rows, cols, proptest::num::f64::NORMAL)
186}
187
188pub fn sparse_matrix_strategy<T, S>(
189 rows: impl Strategy<Value = usize>,
190 cols: impl Strategy<Value = usize>,
191 strategy: S,
192) -> impl Strategy<Value = MockSparseMatrix<T>>
193where
194 T: Debug,
195 S: Clone + Strategy<Value = T>,
196{
197 (rows, cols).prop_flat_map(move |(r, c)| {
200 let max_nnz = r * c;
201 let ij_strategy = (0..r, 0..c);
202 let values_strategy = strategy.clone();
203 proptest::collection::btree_map(ij_strategy, values_strategy, 0..=max_nnz)
205 .prop_map(|map_matrix| map_matrix
206 .into_iter()
207 .map(|((i, j), v)| (i, j, v))
208 .collect())
209 .prop_map(move |triplets| MockSparseMatrix::from_triplets(r, c, triplets))
210 })
211}
212
213pub fn sparse_matrix_strategy_i64(
214 rows: impl Strategy<Value = usize>,
215 cols: impl Strategy<Value = usize>,
216) -> impl Strategy<Value = MockSparseMatrix<i64>> {
217 sparse_matrix_strategy(rows, cols, i64_range())
218}
219
220pub fn sparse_matrix_strategy_normal_f64(
221 rows: impl Strategy<Value = usize>,
222 cols: impl Strategy<Value = usize>,
223) -> impl Strategy<Value = MockSparseMatrix<f64>> {
224 sparse_matrix_strategy(rows, cols, proptest::num::f64::NORMAL)
225}