1use alloc::vec::Vec;
4use core::array;
5
6use p3_field::PackedValue;
7use p3_matrix::{Matrix, dense::RowMajorMatrix};
8use p3_util::log2_strict_usize;
9use serde::{Deserialize, Serialize};
10
11#[inline]
15pub fn log2_strict_u8(n: usize) -> u8 {
16 log2_strict_usize(n) as u8
17}
18
/// A list of rows of possibly different widths, stored contiguously.
///
/// All elements live in a single flat `elems` buffer. `widths[i]` is the number
/// of elements in row `i`, and rows are laid out back to back in order.
///
/// Invariant: `elems.len() == widths.iter().sum()`. [`RowList::new`] asserts this.
///
/// NOTE(review): the derived `Deserialize` impl fills the fields directly and does
/// not re-check this invariant. A malformed payload can therefore produce a value
/// whose row accessors panic on out-of-bounds slicing.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
// Explicit serde bounds: serialization only needs `T: Serialize`, and
// deserialization only needs `T: Deserialize<'de>`.
#[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))]
pub struct RowList<T> {
    /// All row elements, concatenated in row order.
    elems: Vec<T>,
    /// Element count of each row, in row order.
    widths: Vec<usize>,
}
41
42impl<T> RowList<T> {
43 pub fn new(elems: Vec<T>, widths: &[usize]) -> Self {
49 let expected: usize = widths.iter().sum();
50 assert_eq!(
51 elems.len(),
52 expected,
53 "RowList invariant violated: {} elems but widths sum to {}",
54 elems.len(),
55 expected,
56 );
57 Self {
58 elems,
59 widths: widths.to_vec(),
60 }
61 }
62
63 pub fn from_rows<R: AsRef<[T]>>(rows: impl IntoIterator<Item = R>) -> Self
67 where
68 T: Clone,
69 {
70 let mut elems = Vec::new();
71 let mut widths = Vec::new();
72 for row in rows {
73 let row = row.as_ref();
74 widths.push(row.len());
75 elems.extend_from_slice(row);
76 }
77 Self { elems, widths }
78 }
79
80 #[inline]
82 pub fn as_slice(&self) -> &[T] {
83 &self.elems
84 }
85
86 #[inline]
88 pub fn iter_values(&self) -> impl Iterator<Item = T> + '_
89 where
90 T: Copy,
91 {
92 self.elems.iter().copied()
93 }
94
95 #[inline]
97 pub fn num_rows(&self) -> usize {
98 self.widths.len()
99 }
100
101 pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
103 let mut offset = 0;
104 self.widths.iter().map(move |&w| {
105 let row = &self.elems[offset..offset + w];
106 offset += w;
107 row
108 })
109 }
110
111 pub fn row(&self, idx: usize) -> &[T] {
117 let offset: usize = self.widths[..idx].iter().sum();
118 &self.elems[offset..offset + self.widths[idx]]
119 }
120}
121
122impl<T: Copy + Default> RowList<T> {
123 pub fn iter_aligned(&self, alignment: usize) -> impl Iterator<Item = T> + '_ {
132 self.iter_rows().flat_map(move |row| {
133 let padding = aligned_len(row.len(), alignment) - row.len();
134 row.iter()
135 .copied()
136 .chain(core::iter::repeat_n(T::default(), padding))
137 })
138 }
139}
140
141impl<T: Default + Clone> RowList<T> {
142 pub fn from_rows_aligned<R: AsRef<[T]>>(
144 rows: impl IntoIterator<Item = R>,
145 alignment: usize,
146 ) -> Self {
147 let mut elems = Vec::new();
148 let mut widths = Vec::new();
149 for row in rows {
150 let row = row.as_ref();
151 let padded_len = aligned_len(row.len(), alignment);
152 widths.push(padded_len);
153 elems.extend_from_slice(row);
154 elems.resize(elems.len() + (padded_len - row.len()), T::default());
155 }
156 Self { elems, widths }
157 }
158}
159
160pub trait PackedValueExt: PackedValue {
165 #[inline]
170 #[must_use]
171 fn pack_columns<const N: usize>(rows: &[[Self::Value; N]]) -> [Self; N] {
172 assert_eq!(rows.len(), Self::WIDTH);
173 array::from_fn(|col| Self::from_fn(|lane| rows[lane][col]))
174 }
175}
176
177impl<T: PackedValue> PackedValueExt for T {}
179
/// Rounds `len` up to the nearest multiple of `alignment`.
///
/// An `alignment` of 0 or 1 means "no alignment", and `len` is returned as is.
#[inline]
pub const fn aligned_len(len: usize, alignment: usize) -> usize {
    match alignment {
        0 | 1 => len,
        _ => len.next_multiple_of(alignment),
    }
}
189
190pub fn aligned_widths(mut widths: Vec<usize>, alignment: usize) -> Vec<usize> {
192 for w in &mut widths {
193 *w = aligned_len(*w, alignment);
194 }
195 widths
196}
197
198pub fn pad_row_to_alignment<F: Default>(mut row: Vec<F>, alignment: usize) -> Vec<F> {
203 debug_assert!(alignment > 0, "alignment must be non-zero");
204 let padded_len = aligned_len(row.len(), alignment);
205 row.resize_with(padded_len, || F::default());
206 row
207}
208
209pub fn upsample_matrix<F: Clone + Send + Sync>(
217 matrix: &impl Matrix<F>,
218 target_height: usize,
219) -> RowMajorMatrix<F> {
220 let height = matrix.height();
221 assert!(target_height >= height);
222 assert!(height.is_power_of_two() && target_height.is_power_of_two());
223
224 let repeat_factor = target_height / height;
225 let width = matrix.width();
226
227 let mut values = Vec::with_capacity(target_height * width);
228 for row in matrix.rows() {
229 let row_vec: Vec<F> = row.collect();
230 for _ in 0..repeat_factor {
231 values.extend(row_vec.iter().cloned());
232 }
233 }
234
235 RowMajorMatrix::new(values, width)
236}