p3_matrix/util.rs
1use core::borrow::BorrowMut;
2
3use p3_maybe_rayon::prelude::*;
4use p3_util::{log2_strict_usize, reverse_bits_len};
5use tracing::instrument;
6
7use crate::Matrix;
8use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
9
10/// Reverse the order of matrix rows based on the bit-reversal of their indices.
11///
12/// Given a matrix `mat` of height `h = 2^k`, this function rearranges its rows by
13/// reversing the binary representation of each row index. For example, if `h = 8` (i.e., 3 bits):
14///
15/// ```text
16/// Original Index Binary Reversed Target Index
17/// -------------- ------- --------- -------------
18/// 0 000 000 0
19/// 1 001 100 4
20/// 2 010 010 2
21/// 3 011 110 6
22/// 4 100 001 1
23/// 5 101 101 5
24/// 6 110 011 3
25/// 7 111 111 7
26/// ```
27///
28/// The transformation is performed in-place.
29///
30/// # Panics
31/// Panics if the height of the matrix is not a power of two.
32///
33/// # Arguments
34/// - `mat`: The matrix whose rows should be reordered.
35#[instrument(level = "debug", skip_all)]
36pub fn reverse_matrix_index_bits<'a, F, S>(mat: &mut DenseMatrix<F, S>)
37where
38 F: Clone + Send + Sync + 'a,
39 S: DenseStorage<F> + BorrowMut<[F]>,
40{
41 let w = mat.width();
42 let h = mat.height();
43 let log_h = log2_strict_usize(h);
44 let values = mat.values.borrow_mut().as_mut_ptr() as usize;
45
46 // SAFETY: Due to the i < j check, we are guaranteed that `swap_rows_raw
47 // will never try and access a particular slice of data more than once
48 // across all parallel threads. Hence the following code is safe and does
49 // not trigger undefined behaviour.
50 (0..h).into_par_iter().for_each(|i| {
51 let values = values as *mut F;
52 let j = reverse_bits_len(i, log_h);
53 if i < j {
54 unsafe { swap_rows_raw(values, w, i, j) };
55 }
56 });
57}
58
59/// Swap two rows `i` and `j` in a [`RowMajorMatrix`].
60///
61/// # Panics
62/// Panics if the indices are out of bounds or not ordered as `i < j`.
63///
64/// # Arguments
65/// - `mat`: The matrix to modify.
66/// - `i`: The first row index (must be less than `j`).
67/// - `j`: The second row index.
68pub fn swap_rows<F: Clone + Send + Sync>(mat: &mut RowMajorMatrix<F>, i: usize, j: usize) {
69 let w = mat.width();
70 let (upper, lower) = mat.values.split_at_mut(j * w);
71 let row_i = &mut upper[i * w..(i + 1) * w];
72 let row_j = &mut lower[..w];
73 row_i.swap_with_slice(row_j);
74}
75
/// Swap rows `i` and `j` of a row-major buffer in place, working through a raw pointer.
///
/// Unlike [`swap_rows`], this takes a bare pointer instead of a `&mut` matrix. That is what
/// allows [`reverse_matrix_index_bits`] to perform many disjoint swaps concurrently from
/// parallel closures.
///
/// # Safety
/// - `mat` must point to the first element of a row-major matrix of width `w`, and both rows
///   must lie inside that allocation: for a matrix of height `h`, the caller must ensure
///   `i < j < h`. Soundness only needs the two rows to be distinct and in bounds; callers in
///   this module uphold the stronger `i < j`.
/// - For the duration of the call, nothing else may access either row: no live references
///   and no concurrent reads or writes. Other, disjoint rows may be swapped concurrently.
///
/// # Arguments
/// - `mat`: A mutable pointer to the first element of the matrix data.
/// - `w`: The matrix width (number of elements per row).
/// - `i`: The first row index.
/// - `j`: The second row index.
unsafe fn swap_rows_raw<F>(mat: *mut F, w: usize, i: usize, j: usize) {
    // SAFETY: by the caller's contract both `w`-element rows are in bounds, distinct (hence
    // non-overlapping) and exclusively ours, which is exactly what `swap_nonoverlapping`
    // requires. This is the same primitive `<[F]>::swap_with_slice` uses internally, but it
    // avoids materializing two `&mut [F]` from the raw pointer.
    unsafe {
        let row_i = mat.add(i * w);
        let row_j = mat.add(j * w);
        core::ptr::swap_nonoverlapping(row_i, row_j, w);
    }
}
96
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    /// Builds a 4x3 matrix with rows `[1, 2, 3]`, `[4, 5, 6]`, `[7, 8, 9]`, `[10, 11, 12]`.
    fn four_by_three() -> RowMajorMatrix<i32> {
        RowMajorMatrix::new(
            vec![
                1, 2, 3, // row 0
                4, 5, 6, // row 1
                7, 8, 9, // row 2
                10, 11, 12, // row 3
            ],
            3,
        )
    }

    #[test]
    fn test_swap_rows_basic() {
        let mut matrix = four_by_three();

        // Swap rows 0 and 2
        swap_rows(&mut matrix, 0, 2);

        assert_eq!(
            matrix.values,
            vec![
                7, 8, 9, // row 0 (was row 2)
                4, 5, 6, // row 1 (unchanged)
                1, 2, 3, // row 2 (was row 0)
                10, 11, 12, // row 3 (unchanged)
            ]
        );
    }

    #[test]
    fn test_swap_rows_adjacent_last_rows() {
        let mut matrix = four_by_three();

        // Row `j` is the final row, so the back half of the split is exactly one row long.
        swap_rows(&mut matrix, 2, 3);

        assert_eq!(
            matrix.values,
            vec![
                1, 2, 3, // row 0 (unchanged)
                4, 5, 6, // row 1 (unchanged)
                10, 11, 12, // row 2 (was row 3)
                7, 8, 9, // row 3 (was row 2)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn test_swap_rows_equal_indices_should_panic() {
        // `swap_rows` documents that it requires `i < j`.
        let mut matrix = four_by_three();
        swap_rows(&mut matrix, 1, 1);
    }

    #[test]
    #[should_panic]
    fn test_swap_rows_unordered_indices_should_panic() {
        let mut matrix = four_by_three();
        swap_rows(&mut matrix, 3, 1);
    }

    #[test]
    #[should_panic]
    fn test_swap_rows_out_of_bounds_should_panic() {
        // Height is 4, so row index 4 does not exist.
        let mut matrix = four_by_three();
        swap_rows(&mut matrix, 0, 4);
    }

    #[test]
    fn test_swap_rows_raw_basic() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, 3, // row 0
                4, 5, 6, // row 1
                7, 8, 9, // row 2
            ],
            3,
        );
        let ptr = matrix.values.as_mut_ptr();
        // SAFETY: rows 0 and 2 are distinct and in bounds of this 3x3 matrix, and no other
        // reference to its data is live during the call.
        unsafe {
            swap_rows_raw(ptr, matrix.width(), 0, 2);
        }

        assert_eq!(
            matrix.values,
            vec![
                7, 8, 9, // row 0 (was row 2)
                4, 5, 6, // row 1 (unchanged)
                1, 2, 3, // row 2 (was row 0)
            ]
        );
    }

    #[test]
    fn test_reverse_matrix_index_bits_pow2_height() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                0, 1, // row 0
                2, 3, // row 1
                4, 5, // row 2
                6, 7, // row 3
                8, 9, // row 4
                10, 11, // row 5
                12, 13, // row 6
                14, 15, // row 7
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);

        assert_eq!(
            matrix.values,
            vec![
                0, 1, // position 0 <- original row 0 (0b000 is its own reversal)
                8, 9, // position 1 <- original row 4 (0b001 <-> 0b100)
                4, 5, // position 2 <- original row 2 (0b010 is its own reversal)
                12, 13, // position 3 <- original row 6 (0b011 <-> 0b110)
                2, 3, // position 4 <- original row 1 (0b100 <-> 0b001)
                10, 11, // position 5 <- original row 5 (0b101 is its own reversal)
                6, 7, // position 6 <- original row 3 (0b110 <-> 0b011)
                14, 15, // position 7 <- original row 7 (0b111 is its own reversal)
            ]
        );
    }

    #[test]
    fn test_reverse_matrix_index_bits_height_16_round_trip() {
        // Each single-element row holds its own original index, so after the permutation
        // position `i` must contain the 4-bit reversal of `i`.
        let mut matrix = RowMajorMatrix::new(
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
            1,
        );
        let original = matrix.values.clone();

        reverse_matrix_index_bits(&mut matrix);
        assert_eq!(
            matrix.values,
            vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]
        );

        // Bit reversal is an involution, so applying it a second time restores the input.
        reverse_matrix_index_bits(&mut matrix);
        assert_eq!(matrix.values, original);
    }

    #[test]
    fn test_reverse_matrix_index_bits_height_2() {
        // With a single index bit every row is its own reversal, so nothing moves.
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, // row 0
                3, 4, // row 1
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);

        assert_eq!(matrix.values, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_reverse_matrix_index_bits_height_1() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                42, 43, // row 0
            ],
            2,
        );

        // Bit-reversing a height-1 matrix should do nothing.
        reverse_matrix_index_bits(&mut matrix);

        assert_eq!(
            matrix.values,
            vec![
                42, 43, // row 0 (unchanged)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn test_reverse_matrix_index_bits_non_power_of_two_should_panic() {
        // height = 3 → not a power of two → should panic
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, // row 0
                3, 4, // row 1
                5, 6, // row 2
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);
    }
}