use super::StarkDomain;
use core::{iter::FusedIterator, slice};
use crypto::{ElementHasher, MerkleTree};
use math::{fft, polynom, FieldElement};
use utils::{batch_iter_mut, collections::Vec, iter, iter_mut, uninit_vector};
#[cfg(feature = "concurrent")]
use utils::iterators::*;
/// A two-dimensional matrix of field elements stored column by column.
///
/// Each inner vector of `columns` holds one column, so element `(col, row)` lives at
/// `columns[col][row]`. `Matrix::new` checks that there is at least one column, that every
/// column has the same length, and that this length is a power of two greater than one.
#[derive(Debug, Clone)]
pub struct Matrix<E: FieldElement> {
    // One vector per column; all columns are expected to have the same length.
    columns: Vec<Vec<E>>,
}
impl<E: FieldElement> Matrix<E> {
    /// Creates a new matrix whose columns are the provided vectors.
    ///
    /// # Panics
    /// Panics if:
    /// * `columns` is empty;
    /// * the columns have one row or fewer;
    /// * the number of rows is not a power of two;
    /// * the columns do not all have the same length.
    pub fn new(columns: Vec<Vec<E>>) -> Self {
        assert!(
            !columns.is_empty(),
            "a matrix must contain at least one column"
        );

        // The first column defines the height every other column must match.
        let height = columns[0].len();
        assert!(
            height > 1,
            "number of rows in a matrix must be greater than one"
        );
        assert!(
            height.is_power_of_two(),
            "number of rows in a matrix must be a power of 2"
        );
        columns[1..].iter().for_each(|other| {
            assert_eq!(
                other.len(),
                height,
                "all matrix columns must have the same length"
            );
        });

        Self { columns }
    }

    /// Returns the number of columns in this matrix.
    pub fn num_cols(&self) -> usize {
        self.columns.len()
    }

    /// Returns the number of rows in this matrix (the length of the first column).
    pub fn num_rows(&self) -> usize {
        let first = &self.columns[0];
        first.len()
    }

    /// Returns the element located at column `col_idx` and row `row_idx`.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    pub fn get(&self, col_idx: usize, row_idx: usize) -> E {
        let column = &self.columns[col_idx];
        column[row_idx]
    }

    /// Overwrites the element at column `col_idx` and row `row_idx` with `value`.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    pub fn set(&mut self, col_idx: usize, row_idx: usize, value: E) {
        let column = &mut self.columns[col_idx];
        column[row_idx] = value;
    }

    /// Returns the column at `col_idx` as a slice.
    pub fn get_column(&self, col_idx: usize) -> &[E] {
        self.columns[col_idx].as_slice()
    }

    /// Returns the column at `col_idx` as a mutable slice.
    pub fn get_column_mut(&mut self, col_idx: usize) -> &mut [E] {
        self.columns[col_idx].as_mut_slice()
    }

    /// Copies row `row_idx` into `row`.
    ///
    /// Only `min(row.len(), num_cols())` entries are written; any extra entries of `row` are
    /// left untouched.
    pub fn read_row_into(&self, row_idx: usize, row: &mut [E]) {
        row.iter_mut()
            .zip(&self.columns)
            .for_each(|(slot, column)| *slot = column[row_idx]);
    }

    /// Writes the values of `row` into row `row_idx`.
    ///
    /// Only `min(row.len(), num_cols())` columns are updated.
    pub fn update_row(&mut self, row_idx: usize, row: &[E]) {
        for (value, column) in row.iter().copied().zip(self.columns.iter_mut()) {
            column[row_idx] = value;
        }
    }

    /// Returns an iterator over the columns of this matrix.
    pub fn columns(&self) -> ColumnIter<E> {
        ColumnIter::new(self)
    }

    /// Returns an iterator over mutable views of the columns of this matrix.
    pub fn columns_mut(&mut self) -> ColumnIterMut<E> {
        ColumnIterMut::new(self)
    }

    /// Returns a new matrix built by running `fft::interpolate_poly` on a copy of each column.
    ///
    /// Inverse twiddles are computed once for `num_rows()` and shared by all columns. The
    /// columns are presumably evaluations, and the results polynomial coefficients; confirm
    /// against `fft::interpolate_poly`. The work is dispatched through `iter!`, which may run
    /// in parallel when the `concurrent` feature is enabled.
    pub fn interpolate_columns(&self) -> Self {
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
        Self {
            columns: iter!(self.columns)
                .map(|evals| {
                    let mut coeffs = evals.clone();
                    fft::interpolate_poly(&mut coeffs, &inv_twiddles);
                    coeffs
                })
                .collect(),
        }
    }

    /// Like [`Self::interpolate_columns`], but consumes `self` and interpolates each column
    /// in place instead of copying it.
    pub fn interpolate_columns_into(mut self) -> Self {
        let inv_twiddles = fft::get_inv_twiddles::<E::BaseField>(self.num_rows());
        iter_mut!(self.columns).for_each(|coeffs| fft::interpolate_poly(coeffs, &inv_twiddles));
        self
    }

    /// Returns a new matrix whose columns are the evaluations of this matrix's columns over
    /// `domain`.
    ///
    /// Each column is passed to `fft::evaluate_poly_with_offset` together with the domain's
    /// trace twiddles, offset and trace-to-LDE blowup factor.
    pub fn evaluate_columns_over(&self, domain: &StarkDomain<E::BaseField>) -> Self {
        Self {
            columns: iter!(self.columns)
                .map(|coeffs| {
                    fft::evaluate_poly_with_offset(
                        coeffs,
                        domain.trace_twiddles(),
                        domain.offset(),
                        domain.trace_to_lde_blowup(),
                    )
                })
                .collect(),
        }
    }

    /// Evaluates every column, treated as a polynomial, at point `x`.
    ///
    /// Returns one value per column, in column order, computed with `polynom::eval`.
    pub fn evaluate_columns_at<F>(&self, x: F) -> Vec<F>
    where
        F: FieldElement + From<E>,
    {
        iter!(self.columns)
            .map(|coeffs| polynom::eval(coeffs, x))
            .collect()
    }

    /// Hashes every row with `H::hash_elements` and builds a Merkle tree over the row digests.
    ///
    /// # Panics
    /// Panics if `MerkleTree::new` returns an error.
    pub fn commit_to_rows<H>(&self) -> MerkleTree<H>
    where
        H: ElementHasher<BaseField = E::BaseField>,
    {
        // NOTE(review): the vector starts uninitialized. This relies on `batch_iter_mut!`
        // handing every element to the closure below exactly once, and on `H::Digest` not
        // needing its old (uninitialized) value dropped on assignment. TODO confirm both.
        let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
        // `128` is forwarded to `batch_iter_mut!`; presumably a minimum batch size.
        batch_iter_mut!(
            &mut row_hashes,
            128,
            |chunk: &mut [H::Digest], first_row: usize| {
                // One scratch row buffer per batch, reused for every row in the batch.
                let mut scratch = vec![E::ZERO; self.num_cols()];
                for (row_idx, digest) in (first_row..).zip(chunk.iter_mut()) {
                    self.read_row_into(row_idx, &mut scratch);
                    *digest = H::hash_elements(&scratch);
                }
            }
        );
        MerkleTree::new(row_hashes).expect("failed to construct trace Merkle tree")
    }

    /// Consumes the matrix and returns its columns.
    pub fn into_columns(self) -> Vec<Vec<E>> {
        self.columns
    }
}
/// Iterator over the columns of a [`Matrix`], yielding each column as a slice from the first
/// column to the last.
pub struct ColumnIter<'a, E: FieldElement> {
    matrix: &'a Matrix<E>,
    // Index of the next column to yield; never exceeds `matrix.num_cols()`.
    cursor: usize,
}

impl<'a, E: FieldElement> ColumnIter<'a, E> {
    /// Creates an iterator positioned at the first column of `matrix`.
    pub fn new(matrix: &'a Matrix<E>) -> Self {
        Self { matrix, cursor: 0 }
    }

    /// Returns the number of columns that have not been yielded yet.
    fn remaining(&self) -> usize {
        // `next` only advances `cursor` while it is below `num_cols()`, so this cannot
        // underflow.
        self.matrix.num_cols() - self.cursor
    }
}

impl<'a, E: FieldElement> Iterator for ColumnIter<'a, E> {
    type Item = &'a [E];

    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining() == 0 {
            return None;
        }
        let column = self.matrix.get_column(self.cursor);
        self.cursor += 1;
        Some(column)
    }

    /// Reports the exact number of remaining columns, as `ExactSizeIterator` requires.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining();
        (remaining, Some(remaining))
    }
}

impl<'a, E: FieldElement> ExactSizeIterator for ColumnIter<'a, E> {
    /// Returns the number of columns still to be yielded, not the total column count.
    fn len(&self) -> usize {
        self.remaining()
    }
}

impl<'a, E: FieldElement> FusedIterator for ColumnIter<'a, E> {}
/// Iterator over mutable views of the columns of a [`Matrix`], from the first column to the
/// last.
///
/// This wraps the standard mutable slice iterator over the matrix's column vectors. Each
/// column is handed out at most once, and the yielded slices carry the full borrow lifetime
/// `'a` without any `unsafe` code. A previous version rebuilt the slice from `as_ptr()` and
/// wrote through it, which is not permitted for pointers obtained from `as_ptr`.
pub struct ColumnIterMut<'a, E: FieldElement> {
    columns: slice::IterMut<'a, Vec<E>>,
}

impl<'a, E: FieldElement> ColumnIterMut<'a, E> {
    /// Creates an iterator positioned at the first column of `matrix`.
    pub fn new(matrix: &'a mut Matrix<E>) -> Self {
        Self {
            columns: matrix.columns.iter_mut(),
        }
    }
}

impl<'a, E: FieldElement> Iterator for ColumnIterMut<'a, E> {
    type Item = &'a mut [E];

    fn next(&mut self) -> Option<Self::Item> {
        self.columns.next().map(|column| column.as_mut_slice())
    }

    /// Reports the exact number of remaining columns, as `ExactSizeIterator` requires.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.columns.size_hint()
    }
}

impl<'a, E: FieldElement> ExactSizeIterator for ColumnIterMut<'a, E> {
    /// Returns the number of columns still to be yielded, not the total column count.
    fn len(&self) -> usize {
        self.columns.len()
    }
}

// The underlying slice iterator is fused, so this iterator keeps returning `None` once
// exhausted.
impl<'a, E: FieldElement> FusedIterator for ColumnIterMut<'a, E> {}
/// Iterator over the columns of several [`Matrix`] values, in order: all columns of the first
/// matrix, then all columns of the second, and so on.
pub struct MultiColumnIter<'a, E: FieldElement> {
    matrixes: &'a [Matrix<E>],
    // Index of the matrix currently being iterated.
    m_cursor: usize,
    // Index of the next column to yield within `matrixes[m_cursor]`.
    c_cursor: usize,
}

impl<'a, E: FieldElement> MultiColumnIter<'a, E> {
    /// Creates an iterator over the columns of all `matrixes`.
    ///
    /// # Panics
    /// Panics if the matrixes do not all have the same number of rows.
    pub fn new(matrixes: &'a [Matrix<E>]) -> Self {
        if !matrixes.is_empty() {
            let num_rows = matrixes[0].num_rows();
            for matrix in matrixes.iter().skip(1) {
                assert_eq!(
                    matrix.num_rows(),
                    num_rows,
                    "all matrixes must have the same number of rows"
                );
            }
        }
        Self {
            matrixes,
            m_cursor: 0,
            c_cursor: 0,
        }
    }

    /// Returns the number of columns not yet yielded across the current and following
    /// matrixes.
    fn remaining(&self) -> usize {
        if self.matrixes.is_empty() {
            return 0;
        }
        // `c_cursor` never exceeds the current matrix's column count, and that matrix is part
        // of the sum, so the subtraction cannot underflow.
        let unvisited: usize = self.matrixes[self.m_cursor..]
            .iter()
            .map(|matrix| matrix.num_cols())
            .sum();
        unvisited - self.c_cursor
    }
}

impl<'a, E: FieldElement> Iterator for MultiColumnIter<'a, E> {
    type Item = &'a [E];

    fn next(&mut self) -> Option<Self::Item> {
        if self.matrixes.is_empty() {
            return None;
        }
        let matrix = &self.matrixes[self.m_cursor];
        match matrix.num_cols() - self.c_cursor {
            0 => None,
            _ => {
                let column = matrix.get_column(self.c_cursor);
                self.c_cursor += 1;
                // Move to the next matrix once this one is exhausted. Stay parked at the end of
                // the last matrix so that later calls keep returning `None`.
                if self.c_cursor == matrix.num_cols() && self.m_cursor < self.matrixes.len() - 1 {
                    self.m_cursor += 1;
                    self.c_cursor = 0;
                }
                Some(column)
            }
        }
    }

    /// Reports the exact number of remaining columns, as `ExactSizeIterator` requires.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining();
        (remaining, Some(remaining))
    }
}

impl<'a, E: FieldElement> ExactSizeIterator for MultiColumnIter<'a, E> {
    /// Returns the number of columns still to be yielded, not the total across all matrixes.
    fn len(&self) -> usize {
        self.remaining()
    }
}

impl<'a, E: FieldElement> FusedIterator for MultiColumnIter<'a, E> {}