p3_matrix/
util.rs

1use core::borrow::BorrowMut;
2
3use p3_maybe_rayon::prelude::*;
4use p3_util::{log2_strict_usize, reverse_bits_len};
5use tracing::instrument;
6
7use crate::Matrix;
8use crate::dense::{DenseMatrix, DenseStorage, RowMajorMatrix};
9
10/// Reverse the order of matrix rows based on the bit-reversal of their indices.
11///
12/// Given a matrix `mat` of height `h = 2^k`, this function rearranges its rows by
13/// reversing the binary representation of each row index. For example, if `h = 8` (i.e., 3 bits):
14///
15/// ```text
16/// Original Index  Binary   Reversed   Target Index
17/// --------------  -------  ---------  -------------
18///      0          000      000        0
19///      1          001      100        4
20///      2          010      010        2
21///      3          011      110        6
22///      4          100      001        1
23///      5          101      101        5
24///      6          110      011        3
25///      7          111      111        7
26/// ```
27///
28/// The transformation is performed in-place.
29///
30/// # Panics
31/// Panics if the height of the matrix is not a power of two.
32///
33/// # Arguments
34/// - `mat`: The matrix whose rows should be reordered.
35#[instrument(level = "debug", skip_all)]
36pub fn reverse_matrix_index_bits<'a, F, S>(mat: &mut DenseMatrix<F, S>)
37where
38    F: Clone + Send + Sync + 'a,
39    S: DenseStorage<F> + BorrowMut<[F]>,
40{
41    let w = mat.width();
42    let h = mat.height();
43    let log_h = log2_strict_usize(h);
44    let values = mat.values.borrow_mut().as_mut_ptr() as usize;
45
46    // SAFETY: Due to the i < j check, we are guaranteed that `swap_rows_raw
47    // will never try and access a particular slice of data more than once
48    // across all parallel threads. Hence the following code is safe and does
49    // not trigger undefined behaviour.
50    (0..h).into_par_iter().for_each(|i| {
51        let values = values as *mut F;
52        let j = reverse_bits_len(i, log_h);
53        if i < j {
54            unsafe { swap_rows_raw(values, w, i, j) };
55        }
56    });
57}
58
59/// Swap two rows `i` and `j` in a [`RowMajorMatrix`].
60///
61/// # Panics
62/// Panics if the indices are out of bounds or not ordered as `i < j`.
63///
64/// # Arguments
65/// - `mat`: The matrix to modify.
66/// - `i`: The first row index (must be less than `j`).
67/// - `j`: The second row index.
68pub fn swap_rows<F: Clone + Send + Sync>(mat: &mut RowMajorMatrix<F>, i: usize, j: usize) {
69    let w = mat.width();
70    let (upper, lower) = mat.values.split_at_mut(j * w);
71    let row_i = &mut upper[i * w..(i + 1) * w];
72    let row_j = &mut lower[..w];
73    row_i.swap_with_slice(row_j);
74}
75
/// Exchange rows `i` and `j` of a row-major buffer through a raw pointer.
///
/// This is the pointer-based counterpart of [`swap_rows`]. It performs no bounds
/// checks and needs no borrow of the whole matrix. Callers can therefore swap
/// disjoint row pairs from several threads at once.
///
/// # Safety
/// - The caller must ensure `i < j < h`, where `h` is the height of the matrix.
/// - `mat` must point to the first element of the row-major storage of a matrix
///   of width `w`, and be valid for reads and writes of rows `i` and `j`.
/// - Nothing else may access rows `i` or `j` while the swap is in progress.
///
/// # Arguments
/// - `mat`: Pointer to the start of the matrix data.
/// - `w`: Number of elements per row.
/// - `i`: Index of the first row.
/// - `j`: Index of the second row.
unsafe fn swap_rows_raw<F>(mat: *mut F, w: usize, i: usize, j: usize) {
    let first_offset = i * w;
    let second_offset = j * w;
    // SAFETY: the caller guarantees both rows lie inside the allocation behind
    // `mat`, and that `i != j`. The two `w`-element ranges are therefore in bounds
    // and do not overlap, which is exactly what `swap_nonoverlapping` requires.
    unsafe {
        core::ptr::swap_nonoverlapping(mat.add(first_offset), mat.add(second_offset), w);
    }
}
96
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_swap_rows_basic() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, 3, // row 0
                4, 5, 6, // row 1
                7, 8, 9, // row 2
                10, 11, 12, // row 3
            ],
            3,
        );

        // Swap rows 0 and 2
        swap_rows(&mut matrix, 0, 2);

        assert_eq!(
            matrix.values,
            vec![
                7, 8, 9, // row 0 (was row 2)
                4, 5, 6, // row 1 (unchanged)
                1, 2, 3, // row 2 (was row 0)
                10, 11, 12, // row 3 (unchanged)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn test_swap_rows_reversed_indices_should_panic() {
        // `swap_rows` documents that it panics unless `i < j`.
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, // row 0
                3, 4, // row 1
            ],
            2,
        );

        swap_rows(&mut matrix, 1, 0);
    }

    #[test]
    #[should_panic]
    fn test_swap_rows_out_of_bounds_should_panic() {
        // Row 2 does not exist in a height-2 matrix.
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, // row 0
                3, 4, // row 1
            ],
            2,
        );

        swap_rows(&mut matrix, 0, 2);
    }

    #[test]
    fn test_swap_rows_raw_basic() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, 3, // row 0
                4, 5, 6, // row 1
                7, 8, 9, // row 2
            ],
            3,
        );
        let ptr = matrix.values.as_mut_ptr();
        unsafe {
            swap_rows_raw(ptr, matrix.width(), 0, 2);
        }

        assert_eq!(
            matrix.values,
            vec![
                7, 8, 9, // row 0 (was row 2)
                4, 5, 6, // row 1 (unchanged)
                1, 2, 3, // row 2 (was row 0)
            ]
        );
    }

    #[test]
    fn test_reverse_matrix_index_bits_pow2_height() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                0, 1, // row 0
                2, 3, // row 1
                4, 5, // row 2
                6, 7, // row 3
                8, 9, // row 4
                10, 11, // row 5
                12, 13, // row 6
                14, 15, // row 7
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);

        // After the permutation, index `k` holds the original row `rev(k)`.
        assert_eq!(
            matrix.values,
            vec![
                0, 1, // index 0 (0b000) ← original row 0 (0b000)
                8, 9, // index 1 (0b001) ← original row 4 (0b100)
                4, 5, // index 2 (0b010) ← original row 2 (0b010)
                12, 13, // index 3 (0b011) ← original row 6 (0b110)
                2, 3, // index 4 (0b100) ← original row 1 (0b001)
                10, 11, // index 5 (0b101) ← original row 5 (0b101)
                6, 7, // index 6 (0b110) ← original row 3 (0b011)
                14, 15, // index 7 (0b111) ← original row 7 (0b111)
            ]
        );
    }

    #[test]
    fn test_reverse_matrix_index_bits_is_involution() {
        // Height 16, width 3: the first application must move rows (e.g. 1 <-> 8),
        // and the second must restore the original layout.
        let original: Vec<u32> = (0..48).collect();
        let mut matrix = RowMajorMatrix::new(original.clone(), 3);

        reverse_matrix_index_bits(&mut matrix);
        assert_ne!(matrix.values, original);

        reverse_matrix_index_bits(&mut matrix);
        assert_eq!(matrix.values, original);
    }

    #[test]
    fn test_reverse_matrix_index_bits_height_1() {
        let mut matrix = RowMajorMatrix::new(
            vec![
                42, 43, // row 0
            ],
            2,
        );

        // Bit-reversing a height-1 matrix should do nothing.
        reverse_matrix_index_bits(&mut matrix);

        assert_eq!(
            matrix.values,
            vec![
                42, 43, // row 0 (unchanged)
            ]
        );
    }

    #[test]
    #[should_panic]
    fn test_reverse_matrix_index_bits_non_power_of_two_should_panic() {
        // height = 3 → not a power of two → should panic
        let mut matrix = RowMajorMatrix::new(
            vec![
                1, 2, // row 0
                3, 4, // row 1
                5, 6, // row 2
            ],
            2,
        );

        reverse_matrix_index_bits(&mut matrix);
    }
}