use alloc::vec::Vec;
use core::ops::Deref;
use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
use utils::uninit_vector;
use super::ColMatrix;
/// Domains smaller than this are evaluated in a single thread even when the
/// `concurrent` feature is enabled (the parallelization overhead would dominate).
const MIN_CONCURRENT_SIZE: usize = 1024;

/// A group of N base-field columns evaluated over an LDE domain, stored in
/// row-major order: `data[row]` holds one value from each of the N columns.
#[derive(Clone, Debug)]
pub struct Segment<B: StarkField, const N: usize> {
    // row-major storage; each entry is one domain row across all N columns
    data: Vec<[B; N]>,
}
impl<B: StarkField, const N: usize> Segment<B, N> {
    /// Instantiates a new segment by evaluating polynomials `poly_offset..poly_offset + N`
    /// of `polys` over the domain defined by `offsets` and `twiddles`. If fewer than N
    /// polynomials remain in the matrix, the unused trailing columns are zero-filled.
    ///
    /// # Panics
    /// Panics if:
    /// * `offsets.len()` is not a power of two, or is not greater than `polys.num_rows()`.
    /// * `twiddles.len() * 2` != `polys.num_rows()`.
    /// * `poly_offset` >= `polys.num_base_cols()`.
    pub fn new<E>(polys: &ColMatrix<E>, poly_offset: usize, offsets: &[B], twiddles: &[B]) -> Self
    where
        E: FieldElement<BaseField = B>,
    {
        let poly_size = polys.num_rows();
        let domain_size = offsets.len();
        assert!(domain_size.is_power_of_two());
        assert!(domain_size > poly_size);
        assert_eq!(poly_size, twiddles.len() * 2);
        assert!(poly_offset < polys.num_base_cols());

        // When at least N polynomials remain, every element of the buffer will be
        // overwritten by `new_with_buffer`, so it is safe to leave it uninitialized.
        // Otherwise, the trailing columns are never written and must be zeroed here.
        let data = if polys.num_base_cols() - poly_offset >= N {
            // SAFETY: `new_with_buffer` writes all N entries of every row before the
            // buffer is read (full-width `copy_polys` path).
            unsafe { uninit_vector::<[B; N]>(domain_size) }
        } else {
            vec![[B::ZERO; N]; domain_size]
        };

        Self::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
    }

    /// Same as [`Segment::new`], but reuses `data_buffer` as the backing storage
    /// instead of allocating a fresh one.
    ///
    /// # Panics
    /// Panics under the same conditions as [`Segment::new`], and additionally if
    /// `data_buffer.len()` != `offsets.len()`.
    pub fn new_with_buffer<E>(
        data_buffer: Vec<[B; N]>,
        polys: &ColMatrix<E>,
        poly_offset: usize,
        offsets: &[B],
        twiddles: &[B],
    ) -> Self
    where
        E: FieldElement<BaseField = B>,
    {
        let poly_size = polys.num_rows();
        let domain_size = offsets.len();
        let mut data = data_buffer;

        assert!(domain_size.is_power_of_two());
        assert!(domain_size > poly_size);
        assert_eq!(poly_size, twiddles.len() * 2);
        assert!(poly_offset < polys.num_base_cols());
        assert_eq!(data.len(), domain_size);

        // number of polynomials actually available for this segment; smaller than N
        // only for the last segment of a matrix whose width is not a multiple of N
        let num_polys = (polys.num_base_cols() - poly_offset).min(N);

        // evaluate the polynomials one domain-chunk at a time: scale the coefficients
        // by the corresponding domain offsets, then run an in-place FFT on each chunk
        if cfg!(feature = "concurrent") && domain_size >= MIN_CONCURRENT_SIZE {
            // multi-threaded path: process chunks in parallel, then permute the
            // result into natural order
            #[cfg(feature = "concurrent")]
            data.par_chunks_mut(poly_size).zip(offsets.par_chunks(poly_size)).for_each(
                |(d_chunk, o_chunk)| {
                    if num_polys == N {
                        Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
                    } else {
                        Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
                    }
                    concurrent::split_radix_fft(d_chunk, twiddles);
                },
            );
            #[cfg(feature = "concurrent")]
            concurrent::permute(&mut data);
        } else {
            // single-threaded path
            data.chunks_mut(poly_size).zip(offsets.chunks(poly_size)).for_each(
                |(d_chunk, o_chunk)| {
                    if num_polys == N {
                        Self::copy_polys(d_chunk, polys, poly_offset, o_chunk);
                    } else {
                        Self::copy_polys_partial(d_chunk, polys, poly_offset, num_polys, o_chunk);
                    }
                    d_chunk.fft_in_place(twiddles);
                },
            );
            data.permute();
        }

        Segment { data }
    }

    /// Returns the number of rows (domain size) of this segment.
    pub fn num_rows(&self) -> usize {
        self.data.len()
    }

    /// Consumes the segment, returning the underlying row-major data.
    pub fn into_data(self) -> Vec<[B; N]> {
        self.data
    }

    /// Copies N polynomials starting at `poly_offset` into `dest`, scaling each
    /// coefficient by the domain offset of its row. Writes every entry of `dest`.
    fn copy_polys<E: FieldElement<BaseField = B>>(
        dest: &mut [[B; N]],
        polys: &ColMatrix<E>,
        poly_offset: usize,
        offsets: &[B],
    ) {
        // iterator form avoids per-access bounds checks on `dest` and `offsets`
        for (row_idx, (row, &offset)) in dest.iter_mut().zip(offsets).enumerate() {
            for (i, slot) in row.iter_mut().enumerate() {
                *slot = polys.get_base_element(poly_offset + i, row_idx) * offset;
            }
        }
    }

    /// Same as [`Self::copy_polys`], but copies only the first `num_polys` columns;
    /// the remaining columns of `dest` are left untouched (the caller guarantees
    /// they were zero-initialized).
    fn copy_polys_partial<E: FieldElement<BaseField = B>>(
        dest: &mut [[B; N]],
        polys: &ColMatrix<E>,
        poly_offset: usize,
        num_polys: usize,
        offsets: &[B],
    ) {
        debug_assert!(num_polys < N);
        for (row_idx, (row, &offset)) in dest.iter_mut().zip(offsets).enumerate() {
            for (i, slot) in row.iter_mut().take(num_polys).enumerate() {
                *slot = polys.get_base_element(poly_offset + i, row_idx) * offset;
            }
        }
    }
}
// Allow a `Segment` to be used anywhere a `&Vec<[B; N]>` (or `&[[B; N]]`) is
// expected, giving callers read-only access to the row data.
impl<B: StarkField, const N: usize> Deref for Segment<B, N> {
    type Target = Vec<[B; N]>;

    fn deref(&self) -> &Self::Target {
        &self.data
    }
}
#[cfg(feature = "concurrent")]
mod concurrent {
use math::fft::permute_index;
use utils::{iterators::*, rayon};
use super::{FftInputs, StarkField};
#[allow(clippy::needless_range_loop)]
pub fn split_radix_fft<B: StarkField, const N: usize>(data: &mut [[B; N]], twiddles: &[B]) {
let n = data.len();
let g = twiddles[twiddles.len() / 2];
debug_assert_eq!(g.exp((n as u32).into()), B::ONE);
let inner_len = 1_usize << (n.ilog2() / 2);
let outer_len = n / inner_len;
let stretch = outer_len / inner_len;
debug_assert!(outer_len == inner_len || outer_len == 2 * inner_len);
debug_assert_eq!(outer_len * inner_len, n);
transpose_square_stretch(data, inner_len, stretch);
data.par_chunks_mut(outer_len)
.for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0));
transpose_square_stretch(data, inner_len, stretch);
data.par_chunks_mut(outer_len).enumerate().for_each(|(i, row)| {
if i > 0 {
let i = permute_index(inner_len, i);
let inner_twiddle = g.exp_vartime((i as u32).into());
let mut outer_twiddle = inner_twiddle;
for element in row.iter_mut().skip(1) {
for col_idx in 0..N {
element[col_idx] *= outer_twiddle;
}
outer_twiddle *= inner_twiddle;
}
}
row.fft_in_place(twiddles)
});
}
pub fn permute<T: Send>(v: &mut [T]) {
let n = v.len();
let num_batches = rayon::current_num_threads().next_power_of_two() * 2;
let batch_size = n / num_batches;
rayon::scope(|s| {
for batch_idx in 0..num_batches {
let values = unsafe { &mut *(&mut v[..] as *mut [T]) };
s.spawn(move |_| {
let batch_start = batch_idx * batch_size;
let batch_end = batch_start + batch_size;
for i in batch_start..batch_end {
let j = permute_index(n, i);
if j > i {
values.swap(i, j);
}
}
});
}
});
}
fn transpose_square_stretch<T>(data: &mut [T], size: usize, stretch: usize) {
assert_eq!(data.len(), size * size * stretch);
match stretch {
1 => transpose_square_1(data, size),
2 => transpose_square_2(data, size),
_ => unimplemented!("only stretch sizes 1 and 2 are supported"),
}
}
fn transpose_square_1<T>(data: &mut [T], size: usize) {
debug_assert_eq!(data.len(), size * size);
debug_assert_eq!(size % 2, 0, "odd sizes are not supported");
for row in (0..size).step_by(2) {
let i = row * size + row;
data.swap(i + 1, i + size);
for col in (row..size).step_by(2).skip(1) {
let i = row * size + col;
let j = col * size + row;
data.swap(i, j);
data.swap(i + 1, j + size);
data.swap(i + size, j + 1);
data.swap(i + size + 1, j + size + 1);
}
}
}
fn transpose_square_2<T>(data: &mut [T], size: usize) {
debug_assert_eq!(data.len(), 2 * size * size);
for row in 0..size {
for col in (row..size).skip(1) {
let i = (row * size + col) * 2;
let j = (col * size + row) * 2;
data.swap(i, j);
data.swap(i + 1, j + 1);
}
}
}
}