winter_prover/matrix/segments.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6use alloc::vec::Vec;
7use core::ops::Deref;
8
9use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField};
10#[cfg(feature = "concurrent")]
11use utils::iterators::*;
12use utils::uninit_vector;
13
14use super::ColMatrix;
15
16// CONSTANTS
17// ================================================================================================
18
/// Smallest domain size for which segment evaluation uses the multi-threaded code path (and
/// only when the `concurrent` feature is enabled); smaller domains are evaluated in a single
/// thread.
const MIN_CONCURRENT_SIZE: usize = 1 << 10;
21
22// SEGMENT OF ROW-MAJOR MATRIX
23// ================================================================================================
24
25/// A set of columns of a matrix stored in row-major form.
26///
27/// The rows are stored in a single vector where each element is an array of size `N`. A segment
28/// can store [StarkField] elements only, but can be instantiated from a [Matrix] of any extension
29/// of the specified [StarkField]. In such a case, extension field elements are decomposed into
30/// base field elements and then added to the segment.
31#[derive(Clone, Debug)]
32pub struct Segment<B: StarkField, const N: usize> {
33 data: Vec<[B; N]>,
34}
35
36impl<B: StarkField, const N: usize> Segment<B, N> {
37 // CONSTRUCTORS
38 // --------------------------------------------------------------------------------------------
39
40 /// Instantiates a new [Segment] by evaluating polynomials from the provided [ColMatrix]
41 /// starting at the specified offset.
42 ///
43 /// The offset is assumed to be an offset into the view of the matrix where extension field
44 /// elements are decomposed into base field elements. This offset must be compatible with the
45 /// values supplied into [Matrix::get_base_element()] method.
46 ///
47 /// Evaluation is performed over the domain specified by the provided twiddles and offsets.
48 ///
49 /// # Panics
50 /// Panics if:
51 /// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
52 /// - Number of offsets is not a power of two.
53 /// - Number of offsets is smaller than or equal to the polynomial size.
54 /// - The number of twiddles is not half the size of the polynomial size.
55 pub fn new<E>(polys: &ColMatrix<E>, poly_offset: usize, offsets: &[B], twiddles: &[B]) -> Self
56 where
57 E: FieldElement<BaseField = B>,
58 {
59 let poly_size = polys.num_rows();
60 let domain_size = offsets.len();
61 assert!(domain_size.is_power_of_two());
62 assert!(domain_size > poly_size);
63 assert_eq!(poly_size, twiddles.len() * 2);
64 assert!(poly_offset < polys.num_base_cols());
65
66 // allocate memory for the segment
67 let data = if polys.num_base_cols() - poly_offset >= N {
68 // if we will fill the entire segment, we allocate uninitialized memory
69 unsafe { uninit_vector::<[B; N]>(domain_size) }
70 } else {
71 // but if some columns in the segment will remain unfilled, we allocate memory
72 // initialized to zeros to make sure we don't end up with memory with
73 // undefined values
74 vec![[B::ZERO; N]; domain_size]
75 };
76
77 Self::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
78 }
79
80 /// Instantiates a new [Segment] using the provided data buffer by evaluating polynomials in
81 /// the [ColMatrix] starting at the specified offset.
82 ///
83 /// The offset is assumed to be an offset into the view of the matrix where extension field
84 /// elements are decomposed into base field elements. This offset must be compatible with the
85 /// values supplied into [Matrix::get_base_element()] method.
86 ///
87 /// Evaluation is performed over the domain specified by the provided twiddles and offsets.
88 ///
89 /// # Panics
90 /// Panics if:
91 /// - `poly_offset` greater than or equal to the number of base field columns in `polys`.
92 /// - Number of offsets is not a power of two.
93 /// - Number of offsets is smaller than or equal to the polynomial size.
94 /// - The number of twiddles is not half the size of the polynomial size.
95 /// - Number of offsets is smaller than the length of the data buffer
96 pub fn new_with_buffer<E>(
97 data_buffer: Vec<[B; N]>,
98 polys: &ColMatrix<E>,
99 poly_offset: usize,
100 offsets: &[B],
101 twiddles: &[B],
102 ) -> Self
103 where
104 E: FieldElement<BaseField = B>,
105 {
106 let poly_size = polys.num_rows();
107 let domain_size = offsets.len();
108 let mut data = data_buffer;
109
110 assert!(domain_size.is_power_of_two());
111 assert!(domain_size > poly_size);
112 assert_eq!(poly_size, twiddles.len() * 2);
113 assert!(poly_offset < polys.num_base_cols());
114 assert_eq!(data.len(), domain_size);
115
116 // determine the number of polynomials to add to this segment; this number can be either N,
117 // or smaller than N when there are fewer than N polynomials remaining to be processed
118 let num_polys_remaining = polys.num_base_cols() - poly_offset;
119 let num_polys = if num_polys_remaining < N {
120 num_polys_remaining
121 } else {
122 N
123 };
124
125 // evaluate the polynomials either in a single thread or multiple threads, depending
126 // on whether `concurrent` feature is enabled and domain size is greater than 1024;
127
128 if cfg!(feature = "concurrent") && domain_size >= MIN_CONCURRENT_SIZE {
129 #[cfg(feature = "concurrent")]
130 data.par_chunks_mut(poly_size).zip(offsets.par_chunks(poly_size)).for_each(
131 |(d_chunk, o_chunk)| {
132 // TODO: investigate multi-threaded copy
133 if num_polys == N {
134 Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
135 } else {
136 Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
137 }
138 concurrent::split_radix_fft(d_chunk, twiddles);
139 },
140 );
141 #[cfg(feature = "concurrent")]
142 concurrent::permute(&mut data);
143 } else {
144 data.chunks_mut(poly_size).zip(offsets.chunks(poly_size)).for_each(
145 |(d_chunk, o_chunk)| {
146 if num_polys == N {
147 Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
148 } else {
149 Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
150 }
151 d_chunk.fft_in_place(twiddles);
152 },
153 );
154 data.permute();
155 }
156
157 Segment { data }
158 }
159
160 // PUBLIC ACCESSORS
161 // --------------------------------------------------------------------------------------------
162
163 /// Returns the number of rows in this segment.
164 pub fn num_rows(&self) -> usize {
165 self.data.len()
166 }
167
168 /// Returns the underlying vector of arrays for this segment.
169 pub fn into_data(self) -> Vec<[B; N]> {
170 self.data
171 }
172
173 // HELPER METHODS
174 // --------------------------------------------------------------------------------------------
175
176 /// Copies N polynomials starting at the specified base column offset (`poly_offset`) into the
177 /// specified destination. Each polynomial coefficient is offset by the specified offset.
178 fn copy_polys<E: FieldElement<BaseField = B>>(
179 dest: &mut [[B; N]],
180 polys: &ColMatrix<E>,
181 poly_offset: usize,
182 offsets: &[B],
183 ) {
184 for row_idx in 0..dest.len() {
185 for i in 0..N {
186 let coeff = polys.get_base_element(poly_offset + i, row_idx);
187 dest[row_idx][i] = coeff * offsets[row_idx];
188 }
189 }
190 }
191
192 /// Similar to `clone_and_shift` method above, but copies `num_polys` polynomials instead of
193 /// `N` polynomials.
194 ///
195 /// Assumes that `num_polys` is smaller than `N`.
196 fn copy_polys_partial<E: FieldElement<BaseField = B>>(
197 dest: &mut [[B; N]],
198 polys: &ColMatrix<E>,
199 poly_offset: usize,
200 num_polys: usize,
201 offsets: &[B],
202 ) {
203 debug_assert!(num_polys < N);
204 for row_idx in 0..dest.len() {
205 for i in 0..num_polys {
206 let coeff = polys.get_base_element(poly_offset + i, row_idx);
207 dest[row_idx][i] = coeff * offsets[row_idx];
208 }
209 }
210 }
211}
212
213impl<B: StarkField, const N: usize> Deref for Segment<B, N> {
214 type Target = Vec<[B; N]>;
215
216 fn deref(&self) -> &Self::Target {
217 &self.data
218 }
219}
220
221// CONCURRENT FFT IMPLEMENTATION
222// ================================================================================================
223
224/// Multi-threaded implementations of FFT and permutation algorithms. These are very similar to
225/// the functions implemented in `winter-math::fft::concurrent` module, but are adapted to work
226/// with slices of element arrays.
227#[cfg(feature = "concurrent")]
228mod concurrent {
229 use math::fft::permute_index;
230 use utils::{iterators::*, rayon};
231
232 use super::{FftInputs, StarkField};
233
234 /// In-place recursive FFT with permuted output.
235 /// Adapted from: https://github.com/0xProject/OpenZKP/tree/master/algebra/primefield/src/fft
236 #[allow(clippy::needless_range_loop)]
237 pub fn split_radix_fft<B: StarkField, const N: usize>(data: &mut [[B; N]], twiddles: &[B]) {
238 // generator of the domain should be in the middle of twiddles
239 let n = data.len();
240 let g = twiddles[twiddles.len() / 2];
241 debug_assert_eq!(g.exp((n as u32).into()), B::ONE);
242
243 let inner_len = 1_usize << (n.ilog2() / 2);
244 let outer_len = n / inner_len;
245 let stretch = outer_len / inner_len;
246 debug_assert!(outer_len == inner_len || outer_len == 2 * inner_len);
247 debug_assert_eq!(outer_len * inner_len, n);
248
249 // transpose inner x inner x stretch square matrix
250 transpose_square_stretch(data, inner_len, stretch);
251
252 // apply inner FFTs
253 data.par_chunks_mut(outer_len)
254 .for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0));
255
256 // transpose inner x inner x stretch square matrix
257 transpose_square_stretch(data, inner_len, stretch);
258
259 // apply outer FFTs
260 data.par_chunks_mut(outer_len).enumerate().for_each(|(i, row)| {
261 if i > 0 {
262 let i = permute_index(inner_len, i);
263 let inner_twiddle = g.exp_vartime((i as u32).into());
264 let mut outer_twiddle = inner_twiddle;
265 for element in row.iter_mut().skip(1) {
266 for col_idx in 0..N {
267 element[col_idx] *= outer_twiddle;
268 }
269 outer_twiddle *= inner_twiddle;
270 }
271 }
272 row.fft_in_place(twiddles)
273 });
274 }
275
276 // PERMUTATIONS
277 // --------------------------------------------------------------------------------------------
278
279 pub fn permute<T: Send>(v: &mut [T]) {
280 let n = v.len();
281 let num_batches = rayon::current_num_threads().next_power_of_two() * 2;
282 let batch_size = n / num_batches;
283 rayon::scope(|s| {
284 for batch_idx in 0..num_batches {
285 // create another mutable reference to the slice of values to use in a new thread;
286 // this is OK because we never write the same positions in the slice from different
287 // threads
288 let values = unsafe { &mut *(&mut v[..] as *mut [T]) };
289 s.spawn(move |_| {
290 let batch_start = batch_idx * batch_size;
291 let batch_end = batch_start + batch_size;
292 for i in batch_start..batch_end {
293 let j = permute_index(n, i);
294 if j > i {
295 values.swap(i, j);
296 }
297 }
298 });
299 }
300 });
301 }
302
303 // TRANSPOSING
304 // --------------------------------------------------------------------------------------------
305
306 fn transpose_square_stretch<T>(data: &mut [T], size: usize, stretch: usize) {
307 assert_eq!(data.len(), size * size * stretch);
308 match stretch {
309 1 => transpose_square_1(data, size),
310 2 => transpose_square_2(data, size),
311 _ => unimplemented!("only stretch sizes 1 and 2 are supported"),
312 }
313 }
314
/// Transposes, in place, a `size` x `size` matrix stored in row-major order in `data`.
///
/// The matrix is processed in 2x2 blocks, so `size` is expected to be even (checked in debug
/// builds only).
fn transpose_square_1<T>(data: &mut [T], size: usize) {
    debug_assert_eq!(data.len(), size * size);
    debug_assert_eq!(size % 2, 0, "odd sizes are not supported");

    // iterate over upper-left triangle, working in 2x2 blocks
    // TODO: investigate concurrent implementation
    for row in (0..size).step_by(2) {
        // a block on the diagonal only needs its two off-diagonal cells exchanged
        let diag = row * size + row;
        data.swap(diag + 1, diag + size);

        // every block to the right of the diagonal is exchanged, cell by cell, with its
        // mirror block below the diagonal
        for col in ((row + 2)..size).step_by(2) {
            let upper = row * size + col;
            let lower = col * size + row;
            data.swap(upper, lower);
            data.swap(upper + 1, lower + size);
            data.swap(upper + size, lower + 1);
            data.swap(upper + size + 1, lower + size + 1);
        }
    }
}
334
/// Transposes, in place, a `size` x `size` matrix whose cells each span two consecutive
/// elements of `data`; the two elements of a cell keep their relative order.
fn transpose_square_2<T>(data: &mut [T], size: usize) {
    debug_assert_eq!(data.len(), 2 * size * size);

    // iterate over upper-left triangle, working in 1x2 blocks
    // TODO: investigate concurrent implementation
    for row in 0..size {
        for col in (row + 1)..size {
            let upper = (row * size + col) * 2;
            let lower = (col * size + row) * 2;
            data.swap(upper, lower);
            data.swap(upper + 1, lower + 1);
        }
    }
}
349}