use alloc::vec::Vec;
use air::PartitionOptions;
use crypto::{ElementHasher, VectorCommitment};
use math::{fft, FieldElement, StarkField};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
use utils::{batch_iter_mut, flatten_vector_elements, uninit_vector};
use super::{ColMatrix, Segment};
use crate::StarkDomain;
/// A two-dimensional matrix of base-field elements stored in row-major order.
///
/// Rows are reinterpreted on access as sequences of (possibly extension) field elements
/// of type `E` (see [`RowMatrix::row`]). The matrix is built from evaluation segments
/// (see [`RowMatrix::from_segments`]).
#[derive(Clone, Debug)]
pub struct RowMatrix<E: FieldElement> {
    // All rows concatenated back-to-back; `row_width` base elements per row.
    data: Vec<E::BaseField>,
    // Number of base-field elements in each stored row; this is a multiple of the
    // segment batch size N and may exceed `elements_per_row` (trailing padding).
    row_width: usize,
    // Number of meaningful base-field elements in each row.
    elements_per_row: usize,
}
impl<E: FieldElement> RowMatrix<E> {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Evaluates the polynomials in `polys` over a domain `blowup_factor` times larger than
    /// the polynomial size, shifted by the base field's `GENERATOR`, and returns the
    /// evaluations as a row-major matrix.
    ///
    /// `N` is the number of base-field columns processed together in one segment.
    ///
    /// # Panics
    /// Panics if `N` is zero.
    pub fn evaluate_polys<const N: usize>(polys: &ColMatrix<E>, blowup_factor: usize) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");
        let poly_size = polys.num_rows();
        // per-chunk multiplicative offsets of the evaluation domain
        let offsets =
            get_evaluation_offsets::<E>(poly_size, blowup_factor, E::BaseField::GENERATOR);
        let twiddles = fft::get_twiddles::<E::BaseField>(polys.num_rows());
        let segments = build_segments::<E, N>(polys, &twiddles, &offsets);
        Self::from_segments(segments, polys.num_base_cols())
    }

    /// Same as [`RowMatrix::evaluate_polys`], but evaluates over the LDE domain described by
    /// `domain` (its blowup factor, coset offset, and precomputed trace twiddles).
    ///
    /// # Panics
    /// Panics if `N` is zero.
    pub fn evaluate_polys_over<const N: usize>(
        polys: &ColMatrix<E>,
        domain: &StarkDomain<E::BaseField>,
    ) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");
        let poly_size = polys.num_rows();
        let offsets =
            get_evaluation_offsets::<E>(poly_size, domain.trace_to_lde_blowup(), domain.offset());
        let segments = build_segments::<E, N>(polys, domain.trace_twiddles(), &offsets);
        Self::from_segments(segments, polys.num_base_cols())
    }

    /// Assembles a [`RowMatrix`] from a list of segments by transposing them into row-major
    /// order. `elements_per_row` is the number of meaningful base-field elements per row;
    /// the stored row width is `segments.len() * N`, so rows may carry trailing padding.
    ///
    /// # Panics
    /// Panics if `N` is zero, `segments` is empty, or `elements_per_row` exceeds the
    /// resulting row width.
    pub fn from_segments<const N: usize>(
        segments: Vec<Segment<E::BaseField, N>>,
        elements_per_row: usize,
    ) -> Self {
        assert!(N > 0, "batch size N must be greater than zero");
        assert!(!segments.is_empty(), "a list of segments cannot be empty");
        let row_width = segments.len() * N;
        assert!(
            elements_per_row <= row_width,
            "elements per row cannot exceed {row_width}, but was {elements_per_row}"
        );
        // interleave segment rows so that each full matrix row is contiguous
        let result = transpose(segments);
        RowMatrix {
            data: flatten_vector_elements(result),
            row_width,
            elements_per_row,
        }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of columns of this matrix, counted in `E` elements (each `E`
    /// element spans `E::EXTENSION_DEGREE` base-field elements).
    pub fn num_cols(&self) -> usize {
        self.elements_per_row / E::EXTENSION_DEGREE
    }

    /// Returns the number of rows in this matrix.
    pub fn num_rows(&self) -> usize {
        self.data.len() / self.row_width
    }

    /// Returns the element at the specified column and row.
    ///
    /// # Panics
    /// Panics if `row_idx` or `col_idx` is out of bounds.
    pub fn get(&self, col_idx: usize, row_idx: usize) -> E {
        self.row(row_idx)[col_idx]
    }

    /// Returns the meaningful portion of the specified row (padding excluded),
    /// reinterpreted as a slice of `E` elements.
    ///
    /// # Panics
    /// Panics if `row_idx` is out of bounds.
    pub fn row(&self, row_idx: usize) -> &[E] {
        assert!(row_idx < self.num_rows());
        let start = row_idx * self.row_width;
        E::slice_from_base_elements(&self.data[start..start + self.elements_per_row])
    }

    /// Returns the raw underlying storage (all rows, including padding, concatenated).
    pub fn data(&self) -> &[E::BaseField] {
        &self.data
    }

    /// Hashes every row of this matrix and builds a vector commitment over the row digests.
    ///
    /// If the partition size equals the number of columns, each row is hashed in a single
    /// pass; otherwise each row is split into `partition_size` chunks which are hashed
    /// individually and then merged into one digest per row.
    ///
    /// # Panics
    /// Panics if the vector commitment cannot be constructed from the row hashes.
    pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
    where
        H: ElementHasher<BaseField = E::BaseField>,
        V: VectorCommitment<H>,
    {
        // allocate space for one digest per row.
        // SAFETY(review): the memory is uninitialized; each batch closure below writes
        // every digest in its batch before the vector is read — assumes batch_iter_mut
        // covers the full slice (TODO confirm against the utils implementation).
        let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
        let partition_size = partition_options.partition_size::<E>(self.num_cols());
        if partition_size == self.num_cols() {
            // single-partition case: hash each full row directly.
            // 128 is the number of rows hashed per batch (relevant under `concurrent`).
            batch_iter_mut!(
                &mut row_hashes,
                128, |batch: &mut [H::Digest], batch_offset: usize| {
                    for (i, row_hash) in batch.iter_mut().enumerate() {
                        *row_hash = H::hash_elements(self.row(batch_offset + i));
                    }
                }
            );
        } else {
            // multi-partition case: hash fixed-size chunks of each row, then merge.
            let num_partitions = partition_options.num_partitions::<E>(self.num_cols());
            batch_iter_mut!(
                &mut row_hashes,
                128, |batch: &mut [H::Digest], batch_offset: usize| {
                    // reusable buffer for the per-partition digests of one row
                    let mut buffer = vec![H::Digest::default(); num_partitions];
                    for (i, row_hash) in batch.iter_mut().enumerate() {
                        self.row(batch_offset + i)
                            .chunks(partition_size)
                            .zip(buffer.iter_mut())
                            .for_each(|(chunk, buf)| {
                                *buf = H::hash_elements(chunk);
                            });
                        *row_hash = H::merge_many(&buffer);
                    }
                }
            );
        }
        V::new(row_hashes).expect("failed to construct trace vector commitment")
    }
}
/// Returns the multiplicative offsets for evaluating polynomials of size `poly_size` over a
/// domain `blowup_factor` times larger, shifted by `domain_offset`.
///
/// The result contains `poly_size * blowup_factor` entries laid out as `blowup_factor`
/// chunks of `poly_size` elements; chunk `i` holds successive powers (starting from 1) of
/// `g^permute_index(blowup_factor, i) * domain_offset`, where `g` is a root of unity of
/// the full domain size.
pub fn get_evaluation_offsets<E: FieldElement>(
    poly_size: usize,
    blowup_factor: usize,
    domain_offset: E::BaseField,
) -> Vec<E::BaseField> {
    let domain_size = poly_size * blowup_factor;
    let g = E::BaseField::get_root_of_unity(domain_size.ilog2());

    // SAFETY: every element is written below before the vector is returned, since the
    // `poly_size`-long chunks exactly cover all `domain_size` entries.
    let mut offsets = unsafe { uninit_vector(domain_size) };

    // fills one chunk with successive powers of that chunk's offset
    let fill_chunk = |(chunk_idx, chunk): (usize, &mut [E::BaseField])| {
        let exponent = fft::permute_index(blowup_factor, chunk_idx) as u64;
        let chunk_offset = g.exp_vartime(exponent.into()) * domain_offset;
        let mut power = E::BaseField::ONE;
        chunk.iter_mut().for_each(|slot| {
            *slot = power;
            power *= chunk_offset;
        });
    };

    #[cfg(not(feature = "concurrent"))]
    offsets.chunks_mut(poly_size).enumerate().for_each(fill_chunk);

    #[cfg(feature = "concurrent")]
    offsets.par_chunks_mut(poly_size).enumerate().for_each(fill_chunk);

    offsets
}
/// Builds a list of segments covering all base columns of `polys`, where each segment
/// groups `N` base columns and is evaluated using the provided `twiddles` and `offsets`.
///
/// # Panics
/// Panics if `N` is zero. In debug builds, also panics if `twiddles` does not match the
/// polynomial size or `offsets` is not a multiple of the polynomial size.
pub fn build_segments<E: FieldElement, const N: usize>(
    polys: &ColMatrix<E>,
    twiddles: &[E::BaseField],
    offsets: &[E::BaseField],
) -> Vec<Segment<E::BaseField, N>> {
    assert!(N > 0, "batch size N must be greater than zero");
    debug_assert_eq!(polys.num_rows(), twiddles.len() * 2);
    debug_assert_eq!(offsets.len() % polys.num_rows(), 0);

    // number of N-column segments needed to cover all base columns; the last segment may
    // be only partially filled when the column count is not a multiple of N
    let num_segments = polys.num_base_cols().div_ceil(N);

    (0..num_segments)
        .map(|seg_idx| Segment::new(polys, seg_idx * N, offsets, twiddles))
        .collect()
}
/// Transposes a list of segments into a single row-major vector: entry
/// `result[r * num_segs + j]` is row `r` of segment `j`, so the `num_segs` sub-rows making
/// up one full matrix row end up adjacent in memory.
///
/// Expects `segments` to be non-empty (indexes `segments[0]`); the caller
/// (`RowMatrix::from_segments`) asserts this.
fn transpose<B: StarkField, const N: usize>(mut segments: Vec<Segment<B, N>>) -> Vec<[B; N]> {
    let num_rows = segments[0].num_rows();
    let num_segs = segments.len();
    let result_len = num_rows * num_segs;

    // with a single segment, the data is already in the desired layout — reuse it directly
    if segments.len() == 1 {
        return segments.remove(0).into_data();
    }

    // SAFETY(review): the memory is uninitialized; the batches processed below write every
    // entry before the vector is returned — this relies on `num_rows` being divisible by
    // `num_batches` (both appear to be powers of two in practice: num_batches via
    // get_num_batches, num_rows as a trace-domain size — TODO confirm).
    let mut result = unsafe { uninit_vector::<[B; N]>(result_len) };
    let num_batches = get_num_batches(result_len);
    let rows_per_batch = num_rows / num_batches;

    // copies one batch of rows from all segments into the corresponding slice of `result`
    let transpose_batch = |(batch_idx, batch): (usize, &mut [[B; N]])| {
        let row_offset = batch_idx * rows_per_batch;
        for i in 0..rows_per_batch {
            let row_idx = i + row_offset;
            for j in 0..num_segs {
                let v = &segments[j][row_idx];
                batch[i * num_segs + j].copy_from_slice(v);
            }
        }
    };

    // sequential build: num_batches == 1, so a single call covers the entire result
    #[cfg(not(feature = "concurrent"))]
    transpose_batch((0, &mut result));

    // concurrent build: split the result into equal chunks and transpose them in parallel
    #[cfg(feature = "concurrent")]
    result
        .par_chunks_mut(result_len / num_batches)
        .enumerate()
        .for_each(transpose_batch);

    result
}
/// Returns the number of batches to split transposition into when the `concurrent` feature
/// is disabled: always a single batch, since there is no thread pool to fan work out to.
#[cfg(not(feature = "concurrent"))]
fn get_num_batches(_input_size: usize) -> usize {
    1
}
/// Returns the number of batches to split transposition into when the `concurrent`
/// feature is enabled.
///
/// Small inputs are kept in a single batch to avoid parallelization overhead; larger
/// inputs are split into twice the (power-of-two-rounded) number of rayon threads so the
/// scheduler has some slack for load balancing.
#[cfg(feature = "concurrent")]
fn get_num_batches(input_size: usize) -> usize {
    // below this size, parallel dispatch costs more than it saves
    const MIN_CONCURRENT_SIZE: usize = 1024;
    if input_size < MIN_CONCURRENT_SIZE {
        1
    } else {
        utils::rayon::current_num_threads().next_power_of_two() * 2
    }
}