1use core::ops::Deref;
2
3use p3_field::PackedValue;
4
5use crate::Matrix;
6use crate::dense::RowMajorMatrix;
7
8pub trait RowIndexMap: Send + Sync {
13 fn height(&self) -> usize;
15
16 fn map_row_index(&self, r: usize) -> usize;
23
24 fn to_row_major_matrix<T: Clone + Send + Sync, Inner: Matrix<T>>(
29 &self,
30 inner: Inner,
31 ) -> RowMajorMatrix<T> {
32 RowMajorMatrix::new(
33 unsafe {
34 (0..self.height())
36 .flat_map(|r| inner.row_unchecked(self.map_row_index(r)))
37 .collect()
38 },
39 inner.width(),
40 )
41 }
42}
43
/// A view of the matrix `inner` whose rows are chosen by `index_map`.
///
/// Row `r` of the view is row `index_map.map_row_index(r)` of `inner`: the view takes its
/// height from `index_map` and its width from `inner`, and reads rows from `inner` on demand.
#[derive(Clone, Copy, Debug)]
pub struct RowIndexMappedView<IndexMap, Inner> {
    /// Selects which row of `inner` backs each row of the view.
    pub index_map: IndexMap,
    /// The underlying matrix that row data is read from.
    pub inner: Inner,
}
55
56impl<T: Send + Sync + Clone, IndexMap: RowIndexMap, Inner: Matrix<T>> Matrix<T>
57 for RowIndexMappedView<IndexMap, Inner>
58{
59 fn width(&self) -> usize {
60 self.inner.width()
61 }
62
63 fn height(&self) -> usize {
64 self.index_map.height()
65 }
66
67 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
68 unsafe {
69 self.inner.get_unchecked(self.index_map.map_row_index(r), c)
71 }
72 }
73
74 unsafe fn row_unchecked(
75 &self,
76 r: usize,
77 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
78 unsafe {
79 self.inner.row_unchecked(self.index_map.map_row_index(r))
81 }
82 }
83
84 unsafe fn row_subseq_unchecked(
85 &self,
86 r: usize,
87 start: usize,
88 end: usize,
89 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
90 unsafe {
91 self.inner
93 .row_subseq_unchecked(self.index_map.map_row_index(r), start, end)
94 }
95 }
96
97 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
98 unsafe {
99 self.inner
101 .row_slice_unchecked(self.index_map.map_row_index(r))
102 }
103 }
104
105 unsafe fn row_subslice_unchecked(
106 &self,
107 r: usize,
108 start: usize,
109 end: usize,
110 ) -> impl Deref<Target = [T]> {
111 unsafe {
112 self.inner
114 .row_subslice_unchecked(self.index_map.map_row_index(r), start, end)
115 }
116 }
117
118 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
119 where
120 Self: Sized,
121 T: Clone,
122 {
123 self.index_map.to_row_major_matrix(self.inner)
125 }
126
127 fn horizontally_packed_row<'a, P>(
128 &'a self,
129 r: usize,
130 ) -> (
131 impl Iterator<Item = P> + Send + Sync,
132 impl Iterator<Item = T> + Send + Sync,
133 )
134 where
135 P: PackedValue<Value = T>,
136 T: Clone + 'a,
137 {
138 self.inner
139 .horizontally_packed_row(self.index_map.map_row_index(r))
140 }
141
142 fn padded_horizontally_packed_row<'a, P>(
143 &'a self,
144 r: usize,
145 ) -> impl Iterator<Item = P> + Send + Sync
146 where
147 P: PackedValue<Value = T>,
148 T: Clone + Default + 'a,
149 {
150 self.inner
151 .padded_horizontally_packed_row(self.index_map.map_row_index(r))
152 }
153}
154
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::FieldArray;

    use super::*;
    use crate::dense::RowMajorMatrix;

    /// Maps every row to itself; the wrapped value is the height.
    struct IdentityMap(usize);

    impl RowIndexMap for IdentityMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            r
        }
    }

    /// Reverses the row order; the wrapped value is the height.
    struct ReverseMap(usize);

    impl RowIndexMap for ReverseMap {
        fn height(&self) -> usize {
            self.0
        }

        fn map_row_index(&self, r: usize) -> usize {
            self.0 - 1 - r
        }
    }

    /// A one-row view that always reads row 0 of the inner matrix.
    struct ConstantMap;

    impl RowIndexMap for ConstantMap {
        fn height(&self) -> usize {
            1
        }

        fn map_row_index(&self, _r: usize) -> usize {
            0
        }
    }

    #[test]
    fn test_identity_row_index_map() {
        // [1, 2, 3]
        // [4, 5, 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        assert_eq!(mapped_view.height(), 2);
        assert_eq!(mapped_view.width(), 3);

        assert_eq!(mapped_view.get(0, 0).unwrap(), 1);
        assert_eq!(mapped_view.get(1, 2).unwrap(), 6);

        // SAFETY: every (row, column) pair below is inside the 2x3 view.
        unsafe {
            assert_eq!(mapped_view.get_unchecked(0, 1), 2);
            assert_eq!(mapped_view.get_unchecked(1, 0), 4);
        }

        let rows: Vec<Vec<_>> = mapped_view.rows().map(|row| row.collect()).collect();
        assert_eq!(rows, vec![vec![1, 2, 3], vec![4, 5, 6]]);

        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_reverse_row_index_map() {
        // [1, 2, 3]
        // [4, 5, 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Reversed view:
        // [4, 5, 6]
        // [1, 2, 3]
        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        assert_eq!(mapped_view.height(), 2);
        assert_eq!(mapped_view.width(), 3);

        assert_eq!(mapped_view.get(0, 0).unwrap(), 4);
        assert_eq!(mapped_view.get(1, 2).unwrap(), 3);

        // SAFETY: every (row, column) pair below is inside the 2x3 view.
        unsafe {
            assert_eq!(mapped_view.get_unchecked(0, 1), 5);
            assert_eq!(mapped_view.get_unchecked(1, 0), 1);
        }

        let rows: Vec<Vec<_>> = mapped_view.rows().map(|row| row.collect()).collect();
        assert_eq!(rows, vec![vec![4, 5, 6], vec![1, 2, 3]]);

        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.values, vec![4, 5, 6, 1, 2, 3]);
    }

    #[test]
    fn test_horizontally_packed_row() {
        type Packed = FieldArray<BabyBear, 2>;

        // [1, 2]
        // [3, 4]
        let inner = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );

        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        // Row 0 of the reversed view is inner row 1; width 2 packs exactly, leaving no suffix.
        let (packed_iter, suffix_iter) = mapped_view.horizontally_packed_row::<Packed>(0);

        let packed: Vec<_> = packed_iter.collect();
        let suffix: Vec<_> = suffix_iter.collect();

        assert_eq!(
            packed,
            &[Packed::from([BabyBear::new(3), BabyBear::new(4)])]
        );
        assert!(suffix.is_empty());
    }

    #[test]
    fn test_padded_horizontally_packed_row() {
        type Packed = FieldArray<BabyBear, 3>;

        // [1, 2]
        // [3, 4]
        let inner = RowMajorMatrix::new(
            vec![
                BabyBear::new(1),
                BabyBear::new(2),
                BabyBear::new(3),
                BabyBear::new(4),
            ],
            2,
        );

        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        // A packing width of 3 exceeds the row width of 2, so the last lane is zero-padded.
        let packed: Vec<_> = mapped_view
            .padded_horizontally_packed_row::<Packed>(1)
            .collect();

        assert_eq!(
            packed,
            vec![Packed::from([
                BabyBear::new(3),
                BabyBear::new(4),
                BabyBear::new(0),
            ])]
        );
    }

    #[test]
    fn test_row_and_row_slice_methods() {
        // [10, 20, 30]
        // [40, 50, 60]
        let inner = RowMajorMatrix::new(vec![10, 20, 30, 40, 50, 60], 3);

        // Reversed view:
        // [40, 50, 60]
        // [10, 20, 30]
        let mapped_view = RowIndexMappedView {
            index_map: ReverseMap(inner.height()),
            inner,
        };

        assert_eq!(mapped_view.row_slice(0).unwrap().deref(), &[40, 50, 60]);
        assert_eq!(
            mapped_view.row(1).unwrap().into_iter().collect_vec(),
            vec![10, 20, 30]
        );

        // SAFETY: rows 0 and 1 exist and every column range lies within the width of 3.
        unsafe {
            assert_eq!(
                mapped_view.row_unchecked(0).into_iter().collect_vec(),
                vec![40, 50, 60]
            );
            assert_eq!(mapped_view.row_slice_unchecked(1).deref(), &[10, 20, 30]);
            assert_eq!(
                mapped_view.row_subslice_unchecked(0, 1, 3).deref(),
                &[50, 60]
            );
            assert_eq!(
                mapped_view
                    .row_subseq_unchecked(1, 0, 2)
                    .into_iter()
                    .collect_vec(),
                vec![10, 20]
            );
        }

        assert!(mapped_view.row(2).is_none());
        assert!(mapped_view.row_slice(2).is_none());
    }

    #[test]
    fn test_out_of_bounds_access() {
        // [1, 2]
        // [3, 4]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        let mapped_view = RowIndexMappedView {
            index_map: IdentityMap(inner.height()),
            inner,
        };

        assert_eq!(mapped_view.get(2, 1), None);
        assert!(mapped_view.row(5).is_none());
        assert!(mapped_view.row_slice(11).is_none());
        assert_eq!(mapped_view.get(0, 20), None);
    }

    #[test]
    fn test_out_of_bounds_access_with_bad_map() {
        // A single row: [1, 2, 3, 4]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 4);

        let mapped_view = RowIndexMappedView {
            index_map: ConstantMap,
            inner,
        };

        assert_eq!(mapped_view.get(0, 2), Some(3));

        // Row 1 is outside the one-row view even though the map would send it to row 0.
        assert_eq!(mapped_view.get(1, 0), None);
        assert!(mapped_view.row(1).is_none());
        assert!(mapped_view.row_slice(1).is_none());
    }

    #[test]
    fn test_constant_map_changes_height() {
        // [1, 2]
        // [3, 4]
        // [5, 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);

        let mapped_view = RowIndexMappedView {
            index_map: ConstantMap,
            inner,
        };

        // The view exposes only the first row of the taller inner matrix.
        assert_eq!(mapped_view.height(), 1);
        assert_eq!(mapped_view.width(), 2);

        let dense = mapped_view.to_row_major_matrix();
        assert_eq!(dense.values, vec![1, 2]);
        assert_eq!(dense.width(), 2);
        assert_eq!(dense.height(), 1);
    }
}