use alloc::vec::Vec;
use core::{iter::FusedIterator, slice};
use crypto::{ElementHasher, VectorCommitment};
use math::{fft, polynom, FieldElement};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
use utils::{batch_iter_mut, iter, iter_mut, uninit_vector};
use crate::StarkDomain;
/// A two-dimensional matrix of field elements stored column by column.
///
/// Each inner vector of `columns` holds one column. `ColMatrix::new` verifies that there is at
/// least one column and that every column has the same power-of-two length greater than one.
#[derive(Debug, Clone)]
pub struct ColMatrix<E: FieldElement> {
// Column-major storage: `columns[c][r]` is the element in column `c`, row `r`.
columns: Vec<Vec<E>>,
}
impl<E: FieldElement> ColMatrix<E> {
    // CONSTRUCTOR
    // --------------------------------------------------------------------------------------------

    /// Builds a matrix that takes ownership of the provided columns.
    ///
    /// # Panics
    /// Panics if:
    /// * `columns` is empty;
    /// * the first column has fewer than two elements, or its length is not a power of two;
    /// * any other column has a length different from the first column's.
    pub fn new(columns: Vec<Vec<E>>) -> Self {
        assert!(!columns.is_empty(), "a matrix must contain at least one column");
        let height = columns[0].len();
        assert!(height > 1, "number of rows in a matrix must be greater than one");
        assert!(height.is_power_of_two(), "number of rows in a matrix must be a power of 2");
        // The first column defines the expected length; compare all remaining columns to it.
        columns[1..].iter().for_each(|column| {
            assert_eq!(column.len(), height, "all matrix columns must have the same length");
        });
        Self { columns }
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of columns in this matrix.
    pub fn num_cols(&self) -> usize {
        self.columns.len()
    }

    /// Returns the number of columns multiplied by `E::EXTENSION_DEGREE`, i.e. the column count
    /// when every element is viewed as `E::EXTENSION_DEGREE` base field elements.
    pub fn num_base_cols(&self) -> usize {
        E::EXTENSION_DEGREE * self.columns.len()
    }

    /// Returns the number of rows, taken as the length of the first column.
    ///
    /// # Panics
    /// Panics if the matrix has no columns.
    pub fn num_rows(&self) -> usize {
        self.get_column(0).len()
    }

    /// Returns the element located at the specified column and row.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    pub fn get(&self, col_idx: usize, row_idx: usize) -> E {
        self.get_column(col_idx)[row_idx]
    }

    /// Returns a single base field element addressed by a base-column index.
    ///
    /// `base_col_idx` is split into the index of the containing column
    /// (`base_col_idx / E::EXTENSION_DEGREE`) and the index of the base element within the value
    /// stored there (`base_col_idx % E::EXTENSION_DEGREE`).
    ///
    /// # Panics
    /// Panics if the derived column index or `row_idx` is out of bounds.
    pub fn get_base_element(&self, base_col_idx: usize, row_idx: usize) -> E::BaseField {
        let col_idx = base_col_idx / E::EXTENSION_DEGREE;
        let elem_idx = base_col_idx % E::EXTENSION_DEGREE;
        self.get(col_idx, row_idx).base_element(elem_idx)
    }

    /// Overwrites the element at the specified column and row with `value`.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    pub fn set(&mut self, col_idx: usize, row_idx: usize, value: E) {
        self.get_column_mut(col_idx)[row_idx] = value;
    }

    /// Returns the column at `col_idx` as an immutable slice.
    ///
    /// # Panics
    /// Panics if `col_idx` is out of bounds.
    pub fn get_column(&self, col_idx: usize) -> &[E] {
        self.columns[col_idx].as_slice()
    }

    /// Returns the column at `col_idx` as a mutable slice.
    ///
    /// # Panics
    /// Panics if `col_idx` is out of bounds.
    pub fn get_column_mut(&mut self, col_idx: usize) -> &mut [E] {
        self.columns[col_idx].as_mut_slice()
    }

    /// Copies the values of row `row_idx` into `row`, one value per column.
    ///
    /// Columns and slots of `row` are paired positionally; if `row` and the matrix differ in
    /// width, only the overlapping prefix is written.
    ///
    /// # Panics
    /// Panics if `row_idx` is out of bounds for any column that is read.
    pub fn read_row_into(&self, row_idx: usize, row: &mut [E]) {
        self.columns
            .iter()
            .zip(row.iter_mut())
            .for_each(|(column, slot)| *slot = column[row_idx]);
    }

    /// Writes the values of `row` into row `row_idx` of the matrix, one value per column.
    ///
    /// Columns and values of `row` are paired positionally; if `row` and the matrix differ in
    /// width, only the overlapping prefix is updated.
    ///
    /// # Panics
    /// Panics if `row_idx` is out of bounds for any column that is written.
    pub fn update_row(&mut self, row_idx: usize, row: &[E]) {
        self.columns
            .iter_mut()
            .zip(row.iter().copied())
            .for_each(|(column, value)| column[row_idx] = value);
    }

    /// Appends `column` as the new last column of the matrix.
    ///
    /// # Panics
    /// Panics if the matrix already has a column and its length differs from `column`'s length.
    pub fn merge_column(&mut self, column: Vec<E>) {
        if !self.columns.is_empty() {
            assert_eq!(self.columns[0].len(), column.len());
        }
        self.columns.push(column);
    }

    /// Removes the column at `index` and returns it; later columns shift one position down.
    ///
    /// # Panics
    /// Panics if `index` is not smaller than the number of columns.
    pub fn remove_column(&mut self, index: usize) -> Vec<E> {
        assert!(index < self.num_cols(), "column index out of range");
        let removed = self.columns.remove(index);
        removed
    }

    // ITERATION
    // --------------------------------------------------------------------------------------------

    /// Returns an iterator yielding every column of this matrix as an immutable slice.
    pub fn columns(&self) -> ColumnIter<'_, E> {
        ColumnIter::new(self)
    }

    /// Returns an iterator yielding every column of this matrix as a mutable slice.
    pub fn columns_mut(&mut self) -> ColumnIterMut<'_, E> {
        ColumnIterMut::new(self)
    }

    // POLYNOMIAL METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns a new matrix in which every column is the result of running
    /// `fft::interpolate_poly` over a copy of the corresponding column of this matrix.
    ///
    /// The inverse twiddles are computed once for `self.num_rows()` and shared by all columns.
    /// This matrix is left untouched.
    pub fn interpolate_columns(&self) -> Self {
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
        Self {
            columns: iter!(self.columns)
                .map(|source| {
                    // work on a private copy so that `self` keeps its original values
                    let mut poly = source.to_vec();
                    fft::interpolate_poly(&mut poly, &inv_twiddles);
                    poly
                })
                .collect(),
        }
    }

    /// Consumes the matrix, runs `fft::interpolate_poly` over every column in place, and returns
    /// the transformed matrix.
    ///
    /// The inverse twiddles are computed once for `self.num_rows()` and shared by all columns.
    pub fn interpolate_columns_into(mut self) -> Self {
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
        iter_mut!(self.columns).for_each(|poly| fft::interpolate_poly(poly, &inv_twiddles));
        self
    }

    /// Returns a new matrix in which every column is the result of calling
    /// `fft::evaluate_poly_with_offset` on the corresponding column of this matrix, using the
    /// trace twiddles, offset, and trace-to-LDE blowup reported by `domain`.
    pub fn evaluate_columns_over(&self, domain: &StarkDomain<E::BaseField>) -> Self {
        Self {
            columns: iter!(self.columns)
                .map(|coefficients| {
                    fft::evaluate_poly_with_offset(
                        coefficients,
                        domain.trace_twiddles(),
                        domain.offset(),
                        domain.trace_to_lde_blowup(),
                    )
                })
                .collect(),
        }
    }

    /// Evaluates every column, passed as a polynomial to `polynom::eval`, at the point `x` and
    /// returns the results in column order.
    pub fn evaluate_columns_at<F>(&self, x: F) -> Vec<F>
    where
        F: FieldElement + From<E>,
    {
        iter!(self.columns).map(|poly| polynom::eval(poly, x)).collect()
    }

    // COMMITMENTS
    // --------------------------------------------------------------------------------------------

    /// Hashes every row of the matrix with `H::hash_elements` and builds a vector commitment of
    /// type `V` from the resulting row digests.
    ///
    /// # Panics
    /// Panics with "failed to construct trace vector commitment" if `V::new` returns an error.
    pub fn commit_to_rows<H, V>(&self) -> V
    where
        H: ElementHasher<BaseField = E::BaseField>,
        V: VectorCommitment<H>,
    {
        // NOTE(review): `uninit_vector` presumably returns a vector whose contents are not
        // initialized; this relies on `batch_iter_mut!` handing every element to the closure
        // below so that each digest is written before it is read - confirm against `utils`.
        let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };

        batch_iter_mut!(
            &mut row_hashes,
            128, // second macro argument; presumably the minimum batch size - see `batch_iter_mut!`
            |batch: &mut [H::Digest], batch_offset: usize| {
                // one scratch buffer per batch, reused for every row in that batch
                let mut scratch_row = vec![E::ZERO; self.num_cols()];
                for (row_idx, digest) in (batch_offset..).zip(batch.iter_mut()) {
                    self.read_row_into(row_idx, &mut scratch_row);
                    *digest = H::hash_elements(&scratch_row);
                }
            }
        );

        V::new(row_hashes).expect("failed to construct trace vector commitment")
    }

    // CONVERSIONS
    // --------------------------------------------------------------------------------------------

    /// Consumes the matrix and returns its underlying columns.
    pub fn into_columns(self) -> Vec<Vec<E>> {
        let Self { columns } = self;
        columns
    }
}
/// Iterator over the columns of a [`ColMatrix`], yielding each column as an immutable slice.
pub struct ColumnIter<'a, E: FieldElement> {
// Matrix being traversed; `None` for an iterator built with `empty()`, which yields nothing.
matrix: Option<&'a ColMatrix<E>>,
// Index of the next column to yield.
cursor: usize,
}
impl<'a, E: FieldElement> ColumnIter<'a, E> {
pub fn new(matrix: &'a ColMatrix<E>) -> Self {
Self { matrix: Some(matrix), cursor: 0 }
}
pub fn empty() -> Self {
Self { matrix: None, cursor: 0 }
}
}
impl<'a, E: FieldElement> Iterator for ColumnIter<'a, E> {
    type Item = &'a [E];

    /// Returns the next column of the matrix, or `None` once every column has been yielded or
    /// when the iterator is not backed by a matrix.
    fn next(&mut self) -> Option<Self::Item> {
        let matrix = self.matrix?;
        if self.cursor >= matrix.num_cols() {
            return None;
        }
        let column = matrix.get_column(self.cursor);
        self.cursor += 1;
        Some(column)
    }

    /// Reports the exact number of columns still to be yielded as both bounds, keeping
    /// `size_hint` consistent with [`ExactSizeIterator::len`].
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }
}

impl<E: FieldElement> ExactSizeIterator for ColumnIter<'_, E> {
    /// Returns the number of columns that have not been yielded yet (zero when the iterator is
    /// not backed by a matrix).
    fn len(&self) -> usize {
        // Subtract the cursor so that the value shrinks as the iterator advances, as required by
        // the `ExactSizeIterator` contract.
        self.matrix
            .map_or(0, |matrix| matrix.num_cols().saturating_sub(self.cursor))
    }
}

// Once `next` has returned `None` the cursor no longer advances, so every later call also
// returns `None`.
impl<E: FieldElement> FusedIterator for ColumnIter<'_, E> {}
impl<E: FieldElement> Default for ColumnIter<'_, E> {
    /// Returns an iterator that is not backed by any matrix, identical to the one produced by
    /// `ColumnIter::empty()`.
    fn default() -> Self {
        Self { matrix: None, cursor: 0 }
    }
}
/// Iterator over the columns of a [`ColMatrix`], yielding each column as a mutable slice.
pub struct ColumnIterMut<'a, E: FieldElement> {
// Exclusively borrowed matrix whose columns are handed out one at a time.
matrix: &'a mut ColMatrix<E>,
// Index of the next column to yield.
cursor: usize,
}
impl<'a, E: FieldElement> ColumnIterMut<'a, E> {
pub fn new(matrix: &'a mut ColMatrix<E>) -> Self {
Self { matrix, cursor: 0 }
}
}
impl<'a, E: FieldElement> Iterator for ColumnIterMut<'a, E> {
    type Item = &'a mut [E];

    /// Returns the next column of the matrix as a mutable slice, or `None` once every column
    /// has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cursor >= self.matrix.num_cols() {
            return None;
        }
        let column = self.matrix.get_column_mut(self.cursor);
        self.cursor += 1;

        // The borrow checker ties `column` to `&mut self`, not to `'a`, so the slice has to be
        // rebuilt from raw parts. The pointer is taken with `as_mut_ptr` (not `as_ptr`) so that
        // it carries permission to write.
        let len = column.len();
        let ptr = column.as_mut_ptr();

        // SAFETY: `ptr` and `len` come from a live `&mut [E]` into storage owned by the matrix,
        // which this iterator borrows exclusively for `'a`. The cursor only moves forward, so
        // each column is yielded at most once and no two yielded slices alias.
        Some(unsafe { slice::from_raw_parts_mut(ptr, len) })
    }

    /// Reports the exact number of columns still to be yielded as both bounds, keeping
    /// `size_hint` consistent with [`ExactSizeIterator::len`].
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = ExactSizeIterator::len(self);
        (remaining, Some(remaining))
    }
}

impl<E: FieldElement> ExactSizeIterator for ColumnIterMut<'_, E> {
    /// Returns the number of columns that have not been yielded yet.
    fn len(&self) -> usize {
        // Subtract the cursor so that the value shrinks as the iterator advances, as required by
        // the `ExactSizeIterator` contract.
        self.matrix.num_cols().saturating_sub(self.cursor)
    }
}

// Once `next` has returned `None` the cursor no longer advances, so every later call also
// returns `None`.
impl<E: FieldElement> FusedIterator for ColumnIterMut<'_, E> {}